// oxide_sql_core/lexer/tokenizer.rs

1//! SQL Tokenizer implementation.
2
3use super::{Keyword, Span, Token, TokenKind};
4
/// A lexer that tokenizes SQL input.
///
/// Works on byte offsets into `input`; `start..pos` brackets the token
/// currently being scanned (`start` is reset at the top of `next_token`).
pub struct Lexer<'a> {
    /// The input source code.
    input: &'a str,
    /// The current byte position.
    pos: usize,
    /// The byte position of the start of the current token.
    start: usize,
}
14
15impl<'a> Lexer<'a> {
16    /// Creates a new lexer for the given input.
17    #[must_use]
18    pub const fn new(input: &'a str) -> Self {
19        Self {
20            input,
21            pos: 0,
22            start: 0,
23        }
24    }
25
26    /// Returns the current character without advancing.
27    fn peek(&self) -> Option<char> {
28        self.input[self.pos..].chars().next()
29    }
30
31    /// Returns the next character without advancing.
32    fn peek_next(&self) -> Option<char> {
33        let mut chars = self.input[self.pos..].chars();
34        chars.next();
35        chars.next()
36    }
37
38    /// Advances to the next character and returns it.
39    fn advance(&mut self) -> Option<char> {
40        let c = self.peek()?;
41        self.pos += c.len_utf8();
42        Some(c)
43    }
44
45    /// Skips whitespace and comments.
46    fn skip_whitespace_and_comments(&mut self) {
47        loop {
48            // Skip whitespace
49            while self.peek().is_some_and(|c| c.is_whitespace()) {
50                self.advance();
51            }
52
53            // Skip single-line comments (-- ...)
54            if self.peek() == Some('-') && self.peek_next() == Some('-') {
55                self.advance(); // -
56                self.advance(); // -
57                while self.peek().is_some_and(|c| c != '\n') {
58                    self.advance();
59                }
60                continue;
61            }
62
63            // Skip multi-line comments (/* ... */)
64            if self.peek() == Some('/') && self.peek_next() == Some('*') {
65                self.advance(); // /
66                self.advance(); // *
67                loop {
68                    match self.advance() {
69                        Some('*') if self.peek() == Some('/') => {
70                            self.advance();
71                            break;
72                        }
73                        None => break,
74                        _ => {}
75                    }
76                }
77                continue;
78            }
79
80            break;
81        }
82    }
83
    /// Creates a span from start to current position.
    ///
    /// `start` is set at the top of `next_token` before scanning begins, so
    /// this covers exactly the bytes of the token just scanned.
    fn make_span(&self) -> Span {
        Span::new(self.start, self.pos)
    }
88
    /// Creates a token with the current span.
    ///
    /// Convenience wrapper pairing `kind` with the span from `make_span`.
    fn make_token(&self, kind: TokenKind) -> Token {
        Token::new(kind, self.make_span())
    }
93
94    /// Scans an identifier or keyword.
95    fn scan_identifier(&mut self) -> Token {
96        while self.peek().is_some_and(|c| c.is_alphanumeric() || c == '_') {
97            self.advance();
98        }
99
100        let text = &self.input[self.start..self.pos];
101
102        // Check if it's a keyword
103        if let Some(keyword) = Keyword::from_str(text) {
104            self.make_token(TokenKind::Keyword(keyword))
105        } else {
106            self.make_token(TokenKind::Identifier(String::from(text)))
107        }
108    }
109
110    /// Scans a quoted identifier (e.g., "column name" or `column name`).
111    fn scan_quoted_identifier(&mut self, quote: char) -> Token {
112        self.advance(); // consume opening quote
113        let content_start = self.pos;
114
115        loop {
116            match self.peek() {
117                Some(c) if c == quote => {
118                    // Check for escaped quote (double quote)
119                    if self.peek_next() == Some(quote) {
120                        self.advance();
121                        self.advance();
122                    } else {
123                        break;
124                    }
125                }
126                Some(_) => {
127                    self.advance();
128                }
129                None => {
130                    return self.make_token(TokenKind::Error(String::from(
131                        "Unterminated quoted identifier",
132                    )));
133                }
134            }
135        }
136
137        let content = &self.input[content_start..self.pos];
138        self.advance(); // consume closing quote
139
140        // Handle escaped quotes
141        let unescaped = content.replace(&format!("{quote}{quote}"), &quote.to_string());
142        self.make_token(TokenKind::Identifier(unescaped))
143    }
144
145    /// Scans a number (integer or float).
146    fn scan_number(&mut self) -> Token {
147        let mut is_float = false;
148
149        while self.peek().is_some_and(|c| c.is_ascii_digit()) {
150            self.advance();
151        }
152
153        // Check for decimal point
154        if self.peek() == Some('.') && self.peek_next().is_some_and(|c| c.is_ascii_digit()) {
155            is_float = true;
156            self.advance(); // consume .
157            while self.peek().is_some_and(|c| c.is_ascii_digit()) {
158                self.advance();
159            }
160        }
161
162        // Check for exponent
163        if self.peek().is_some_and(|c| c == 'e' || c == 'E') {
164            is_float = true;
165            self.advance(); // consume e/E
166            if self.peek().is_some_and(|c| c == '+' || c == '-') {
167                self.advance();
168            }
169            while self.peek().is_some_and(|c| c.is_ascii_digit()) {
170                self.advance();
171            }
172        }
173
174        let text = &self.input[self.start..self.pos];
175
176        if is_float {
177            match text.parse::<f64>() {
178                Ok(f) => self.make_token(TokenKind::Float(f)),
179                Err(e) => self.make_token(TokenKind::Error(format!("Invalid float: {e}"))),
180            }
181        } else {
182            match text.parse::<i64>() {
183                Ok(i) => self.make_token(TokenKind::Integer(i)),
184                Err(e) => self.make_token(TokenKind::Error(format!("Invalid integer: {e}"))),
185            }
186        }
187    }
188
189    /// Scans a string literal.
190    fn scan_string(&mut self, quote: char) -> Token {
191        self.advance(); // consume opening quote
192        let mut value = String::new();
193
194        loop {
195            match self.peek() {
196                Some(c) if c == quote => {
197                    // Check for escaped quote (double quote)
198                    if self.peek_next() == Some(quote) {
199                        value.push(quote);
200                        self.advance();
201                        self.advance();
202                    } else {
203                        break;
204                    }
205                }
206                Some(c) => {
207                    value.push(c);
208                    self.advance();
209                }
210                None => {
211                    return self.make_token(TokenKind::Error(String::from(
212                        "Unterminated string literal",
213                    )));
214                }
215            }
216        }
217
218        self.advance(); // consume closing quote
219        self.make_token(TokenKind::String(value))
220    }
221
222    /// Scans a blob literal (X'...' or x'...').
223    fn scan_blob(&mut self) -> Token {
224        self.advance(); // consume X/x
225        if self.peek() != Some('\'') {
226            return self.scan_identifier();
227        }
228        self.advance(); // consume opening quote
229
230        let mut bytes = Vec::new();
231        let mut hex_chars = String::new();
232
233        loop {
234            match self.peek() {
235                Some('\'') => break,
236                Some(c) if c.is_ascii_hexdigit() => {
237                    hex_chars.push(c);
238                    self.advance();
239
240                    if hex_chars.len() == 2 {
241                        if let Ok(byte) = u8::from_str_radix(&hex_chars, 16) {
242                            bytes.push(byte);
243                        }
244                        hex_chars.clear();
245                    }
246                }
247                Some(c) if c.is_whitespace() => {
248                    self.advance();
249                }
250                Some(_) => {
251                    return self.make_token(TokenKind::Error(String::from(
252                        "Invalid character in blob literal",
253                    )));
254                }
255                None => {
256                    return self
257                        .make_token(TokenKind::Error(String::from("Unterminated blob literal")));
258                }
259            }
260        }
261
262        if !hex_chars.is_empty() {
263            return self.make_token(TokenKind::Error(String::from(
264                "Odd number of hex digits in blob literal",
265            )));
266        }
267
268        self.advance(); // consume closing quote
269        self.make_token(TokenKind::Blob(bytes))
270    }
271
272    /// Scans the next token.
273    #[must_use]
274    pub fn next_token(&mut self) -> Token {
275        self.skip_whitespace_and_comments();
276        self.start = self.pos;
277
278        let c = match self.advance() {
279            Some(c) => c,
280            None => return self.make_token(TokenKind::Eof),
281        };
282
283        match c {
284            // Single-character tokens
285            '(' => self.make_token(TokenKind::LeftParen),
286            ')' => self.make_token(TokenKind::RightParen),
287            '[' => self.make_token(TokenKind::LeftBracket),
288            ']' => self.make_token(TokenKind::RightBracket),
289            ',' => self.make_token(TokenKind::Comma),
290            ';' => self.make_token(TokenKind::Semicolon),
291            '+' => self.make_token(TokenKind::Plus),
292            '-' => self.make_token(TokenKind::Minus),
293            '*' => self.make_token(TokenKind::Star),
294            '/' => self.make_token(TokenKind::Slash),
295            '%' => self.make_token(TokenKind::Percent),
296            '~' => self.make_token(TokenKind::BitNot),
297            '?' => self.make_token(TokenKind::Question),
298            '@' => self.make_token(TokenKind::At),
299
300            // Potentially multi-character tokens
301            '.' => self.make_token(TokenKind::Dot),
302            ':' => {
303                if self.peek() == Some(':') {
304                    self.advance();
305                    self.make_token(TokenKind::DoubleColon)
306                } else {
307                    self.make_token(TokenKind::Colon)
308                }
309            }
310            '=' => self.make_token(TokenKind::Eq),
311            '<' => {
312                if self.peek() == Some('=') {
313                    self.advance();
314                    self.make_token(TokenKind::LtEq)
315                } else if self.peek() == Some('>') {
316                    self.advance();
317                    self.make_token(TokenKind::NotEq)
318                } else if self.peek() == Some('<') {
319                    self.advance();
320                    self.make_token(TokenKind::LeftShift)
321                } else {
322                    self.make_token(TokenKind::Lt)
323                }
324            }
325            '>' => {
326                if self.peek() == Some('=') {
327                    self.advance();
328                    self.make_token(TokenKind::GtEq)
329                } else if self.peek() == Some('>') {
330                    self.advance();
331                    self.make_token(TokenKind::RightShift)
332                } else {
333                    self.make_token(TokenKind::Gt)
334                }
335            }
336            '!' => {
337                if self.peek() == Some('=') {
338                    self.advance();
339                    self.make_token(TokenKind::NotEq)
340                } else {
341                    self.make_token(TokenKind::Error(String::from("Unexpected character: !")))
342                }
343            }
344            '|' => {
345                if self.peek() == Some('|') {
346                    self.advance();
347                    self.make_token(TokenKind::Concat)
348                } else {
349                    self.make_token(TokenKind::BitOr)
350                }
351            }
352            '&' => self.make_token(TokenKind::BitAnd),
353
354            // String literals
355            '\'' => {
356                self.pos = self.start; // Reset position to scan from quote
357                self.scan_string('\'')
358            }
359
360            // Quoted identifiers
361            '"' => {
362                self.pos = self.start;
363                self.scan_quoted_identifier('"')
364            }
365            '`' => {
366                self.pos = self.start;
367                self.scan_quoted_identifier('`')
368            }
369
370            // Blob literals
371            'X' | 'x' if self.peek() == Some('\'') => {
372                self.pos = self.start;
373                self.scan_blob()
374            }
375
376            // Numbers
377            c if c.is_ascii_digit() => {
378                self.pos = self.start;
379                self.scan_number()
380            }
381
382            // Identifiers and keywords
383            c if c.is_alphabetic() || c == '_' => {
384                self.pos = self.start;
385                self.scan_identifier()
386            }
387
388            _ => self.make_token(TokenKind::Error(format!("Unexpected character: {c}"))),
389        }
390    }
391
392    /// Tokenizes the entire input and returns all tokens.
393    #[must_use]
394    pub fn tokenize(&mut self) -> Vec<Token> {
395        let mut tokens = Vec::new();
396        loop {
397            let token = self.next_token();
398            let is_eof = token.is_eof();
399            tokens.push(token);
400            if is_eof {
401                break;
402            }
403        }
404        tokens
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the lexer over `input` and returns every token, including `Eof`.
    fn tokenize(input: &str) -> Vec<Token> {
        Lexer::new(input).tokenize()
    }

    /// Like [`tokenize`], but discards spans and keeps only the token kinds.
    fn token_kinds(input: &str) -> Vec<TokenKind> {
        tokenize(input).into_iter().map(|t| t.kind).collect()
    }

    #[test]
    fn test_empty_input() {
        let tokens = tokenize("");
        assert_eq!(tokens.len(), 1);
        assert!(matches!(tokens[0].kind, TokenKind::Eof));
    }

    #[test]
    fn test_whitespace_only() {
        let tokens = tokenize("   \n\t  ");
        assert_eq!(tokens.len(), 1);
        assert!(matches!(tokens[0].kind, TokenKind::Eof));
    }

    #[test]
    fn test_single_line_comment() {
        assert_eq!(
            token_kinds("SELECT -- comment\nFROM"),
            vec![
                TokenKind::Keyword(Keyword::Select),
                TokenKind::Keyword(Keyword::From),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_multi_line_comment() {
        assert_eq!(
            token_kinds("SELECT /* comment */ FROM"),
            vec![
                TokenKind::Keyword(Keyword::Select),
                TokenKind::Keyword(Keyword::From),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_keywords() {
        assert_eq!(
            token_kinds("SELECT FROM WHERE"),
            vec![
                TokenKind::Keyword(Keyword::Select),
                TokenKind::Keyword(Keyword::From),
                TokenKind::Keyword(Keyword::Where),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_keywords_case_insensitive() {
        assert_eq!(
            token_kinds("select FROM wHeRe"),
            vec![
                TokenKind::Keyword(Keyword::Select),
                TokenKind::Keyword(Keyword::From),
                TokenKind::Keyword(Keyword::Where),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_identifiers() {
        assert_eq!(
            token_kinds("foo bar_baz _qux"),
            vec![
                TokenKind::Identifier(String::from("foo")),
                TokenKind::Identifier(String::from("bar_baz")),
                TokenKind::Identifier(String::from("_qux")),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_quoted_identifiers() {
        assert_eq!(
            token_kinds("\"column name\" `another`"),
            vec![
                TokenKind::Identifier(String::from("column name")),
                TokenKind::Identifier(String::from("another")),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_integers() {
        assert_eq!(
            token_kinds("42 0 123456789"),
            vec![
                TokenKind::Integer(42),
                TokenKind::Integer(0),
                TokenKind::Integer(123_456_789),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_floats() {
        assert_eq!(
            token_kinds("2.5 0.5 1e10 2.5e-3"),
            vec![
                TokenKind::Float(2.5),
                TokenKind::Float(0.5),
                TokenKind::Float(1e10),
                TokenKind::Float(2.5e-3),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_strings() {
        assert_eq!(
            token_kinds("'hello' 'world'"),
            vec![
                TokenKind::String(String::from("hello")),
                TokenKind::String(String::from("world")),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_string_with_escaped_quote() {
        // A doubled single quote inside a string is an escaped quote.
        assert_eq!(
            token_kinds("'it''s'"),
            vec![TokenKind::String(String::from("it's")), TokenKind::Eof,]
        );
    }

    #[test]
    fn test_blob() {
        // X'48454C4C4F' decodes to the ASCII bytes of "HELLO".
        let tokens = tokenize("X'48454C4C4F'");
        assert_eq!(tokens.len(), 2);
        assert!(
            matches!(&tokens[0].kind, TokenKind::Blob(b) if b == &[0x48, 0x45, 0x4C, 0x4C, 0x4F])
        );
    }

    #[test]
    fn test_operators() {
        // Both != and <> lex to NotEq.
        assert_eq!(
            token_kinds("+ - * / % = != <> < <= > >="),
            vec![
                TokenKind::Plus,
                TokenKind::Minus,
                TokenKind::Star,
                TokenKind::Slash,
                TokenKind::Percent,
                TokenKind::Eq,
                TokenKind::NotEq,
                TokenKind::NotEq,
                TokenKind::Lt,
                TokenKind::LtEq,
                TokenKind::Gt,
                TokenKind::GtEq,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_delimiters() {
        assert_eq!(
            token_kinds("( ) [ ] , ; . : ::"),
            vec![
                TokenKind::LeftParen,
                TokenKind::RightParen,
                TokenKind::LeftBracket,
                TokenKind::RightBracket,
                TokenKind::Comma,
                TokenKind::Semicolon,
                TokenKind::Dot,
                TokenKind::Colon,
                TokenKind::DoubleColon,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_concat_operator() {
        assert_eq!(
            token_kinds("a || b"),
            vec![
                TokenKind::Identifier(String::from("a")),
                TokenKind::Concat,
                TokenKind::Identifier(String::from("b")),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_bitwise_operators() {
        assert_eq!(
            token_kinds("a & b | c ~ d << e >> f"),
            vec![
                TokenKind::Identifier(String::from("a")),
                TokenKind::BitAnd,
                TokenKind::Identifier(String::from("b")),
                TokenKind::BitOr,
                TokenKind::Identifier(String::from("c")),
                TokenKind::BitNot,
                TokenKind::Identifier(String::from("d")),
                TokenKind::LeftShift,
                TokenKind::Identifier(String::from("e")),
                TokenKind::RightShift,
                TokenKind::Identifier(String::from("f")),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_simple_select() {
        let sql = "SELECT id, name FROM users WHERE active = 1";
        assert_eq!(
            token_kinds(sql),
            vec![
                TokenKind::Keyword(Keyword::Select),
                TokenKind::Identifier(String::from("id")),
                TokenKind::Comma,
                TokenKind::Identifier(String::from("name")),
                TokenKind::Keyword(Keyword::From),
                TokenKind::Identifier(String::from("users")),
                TokenKind::Keyword(Keyword::Where),
                TokenKind::Identifier(String::from("active")),
                TokenKind::Eq,
                TokenKind::Integer(1),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_span_tracking() {
        // "SELECT" occupies bytes 0..6; "id" starts after the space at 7.
        let tokens = tokenize("SELECT id");
        assert_eq!(tokens[0].span, Span::new(0, 6));
        assert_eq!(tokens[1].span, Span::new(7, 9));
    }

    #[test]
    fn test_parameter_placeholder() {
        // Placeholders are not fused: "?1" lexes as Question then Integer,
        // and @/: are separate tokens preceding their identifiers.
        assert_eq!(
            token_kinds("? ?1 @param :param"),
            vec![
                TokenKind::Question,
                TokenKind::Question,
                TokenKind::Integer(1),
                TokenKind::At,
                TokenKind::Identifier(String::from("param")),
                TokenKind::Colon,
                TokenKind::Identifier(String::from("param")),
                TokenKind::Eof,
            ]
        );
    }
}