oxide_sql_derive/
lib.rs

1//! Derive macros for type-safe SQL table definitions.
2//!
3//! This crate provides the `#[derive(Table)]` macro for defining database tables
4//! with compile-time checked column names.
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::{format_ident, quote};
9use syn::{Attribute, Data, DeriveInput, Expr, Fields, Ident, Lit, Meta, Type, parse_macro_input};
10
11/// Derives the `Table` trait for a struct, generating type-safe column accessors.
12///
13/// # Attributes
14///
15/// - `#[table(name = "table_name")]` - Specifies the SQL table name (optional,
16///   defaults to snake_case of struct name)
17///
18/// # Field Attributes
19///
20/// - `#[column(primary_key)]` - Marks the field as primary key
21/// - `#[column(name = "column_name")]` - Specifies the SQL column name
22///   (optional, defaults to field name)
23/// - `#[column(nullable)]` - Marks the column as nullable
24/// - `#[column(unique)]` - Marks the column as UNIQUE
25/// - `#[column(autoincrement)]` - Marks the column as AUTOINCREMENT
26/// - `#[column(default = "expr")]` - Sets a raw SQL default expression
27///
28/// # Generated Items
29///
30/// For a struct `User`, this macro generates:
31///
32/// - `UserTable` - A type implementing `Table` trait with table metadata
33/// - `UserColumns` - A module containing column types (`Id`, `Name`, etc.)
34/// - Column accessor methods on `UserTable`
35#[proc_macro_derive(Table, attributes(table, column))]
36pub fn derive_table(input: TokenStream) -> TokenStream {
37    let input = parse_macro_input!(input as DeriveInput);
38    derive_table_impl(input)
39        .unwrap_or_else(|e| e.to_compile_error())
40        .into()
41}
42
43fn derive_table_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
44    let struct_name = &input.ident;
45    let table_name = get_table_name(&input.attrs, struct_name)?;
46
47    let fields = match &input.data {
48        Data::Struct(data) => match &data.fields {
49            Fields::Named(fields) => &fields.named,
50            _ => {
51                return Err(syn::Error::new_spanned(
52                    &input,
53                    "Table derive only supports structs with named fields",
54                ));
55            }
56        },
57        _ => {
58            return Err(syn::Error::new_spanned(
59                &input,
60                "Table derive only supports structs",
61            ));
62        }
63    };
64
65    // Collect field information
66    let mut column_infos: Vec<ColumnInfo> = Vec::new();
67    for field in fields {
68        let field_name = field.ident.as_ref().unwrap();
69        let field_type = &field.ty;
70        let column_attrs = parse_column_attrs(&field.attrs)?;
71
72        column_infos.push(ColumnInfo {
73            field_name: field_name.clone(),
74            field_type: field_type.clone(),
75            column_name: column_attrs.name.unwrap_or_else(|| field_name.to_string()),
76            is_primary_key: column_attrs.primary_key,
77            is_nullable: column_attrs.nullable,
78            is_unique: column_attrs.unique,
79            is_autoincrement: column_attrs.autoincrement,
80            default_expr: column_attrs.default_expr,
81        });
82    }
83
84    // Generate column type names (PascalCase)
85    let column_type_names: Vec<Ident> = column_infos
86        .iter()
87        .map(|c| format_ident!("{}", to_pascal_case(&c.field_name.to_string())))
88        .collect();
89
90    // Generate the table struct name
91    let table_struct_name = format_ident!("{}Table", struct_name);
92    let columns_mod_name = format_ident!("{}Columns", struct_name);
93
94    // Generate column structs
95    let column_structs: Vec<TokenStream2> = column_infos
96        .iter()
97        .zip(column_type_names.iter())
98        .map(|(info, type_name)| {
99            let column_name = &info.column_name;
100            let field_type = &info.field_type;
101            let is_nullable = info.is_nullable;
102            let is_primary_key = info.is_primary_key;
103
104            quote! {
105                /// Column type for compile-time checked queries.
106                #[derive(Debug, Clone, Copy)]
107                pub struct #type_name;
108
109                impl ::oxide_sql_core::schema::Column for #type_name {
110                    type Table = super::#table_struct_name;
111                    type Type = #field_type;
112
113                    const NAME: &'static str = #column_name;
114                    const NULLABLE: bool = #is_nullable;
115                    const PRIMARY_KEY: bool = #is_primary_key;
116                }
117
118                impl ::oxide_sql_core::schema::TypedColumn<#field_type> for #type_name {}
119            }
120        })
121        .collect();
122
123    // Generate column accessor methods
124    let column_accessors: Vec<TokenStream2> = column_infos
125        .iter()
126        .zip(column_type_names.iter())
127        .map(|(info, type_name)| {
128            let method_name = &info.field_name;
129            quote! {
130                /// Returns the column type for type-safe queries.
131                #[inline]
132                pub const fn #method_name() -> #columns_mod_name::#type_name {
133                    #columns_mod_name::#type_name
134                }
135            }
136        })
137        .collect();
138
139    // Generate list of all column names
140    let all_column_names: Vec<&str> = column_infos
141        .iter()
142        .map(|c| c.column_name.as_str())
143        .collect();
144
145    // Find primary key column
146    let primary_key_column = column_infos
147        .iter()
148        .find(|c| c.is_primary_key)
149        .map(|c| &c.column_name);
150
151    let primary_key_impl = if let Some(pk) = primary_key_column {
152        quote! {
153            const PRIMARY_KEY: Option<&'static str> = Some(#pk);
154        }
155    } else {
156        quote! {
157            const PRIMARY_KEY: Option<&'static str> = None;
158        }
159    };
160
161    // Generate TableSchema column entries
162    let schema_entries: Vec<TokenStream2> = column_infos
163        .iter()
164        .map(|info| {
165            let col_name = &info.column_name;
166            let field_type = &info.field_type;
167            let rust_type_str = quote!(#field_type).to_string().replace(' ', "");
168            let is_nullable = info.is_nullable;
169            let is_primary_key = info.is_primary_key;
170            let is_unique = info.is_unique;
171            let is_autoincrement = info.is_autoincrement;
172            let default_expr_token = match &info.default_expr {
173                Some(expr) => quote! { Some(#expr) },
174                None => quote! { None },
175            };
176
177            quote! {
178                ::oxide_sql_core::schema::ColumnSchema {
179                    name: #col_name,
180                    rust_type: #rust_type_str,
181                    nullable: #is_nullable,
182                    primary_key: #is_primary_key,
183                    unique: #is_unique,
184                    autoincrement: #is_autoincrement,
185                    default_expr: #default_expr_token,
186                }
187            }
188        })
189        .collect();
190
191    let expanded = quote! {
192        /// Column types for `#struct_name` table.
193        #[allow(non_snake_case)]
194        pub mod #columns_mod_name {
195            #(#column_structs)*
196        }
197
198        /// Table metadata for `#struct_name`.
199        #[derive(Debug, Clone, Copy)]
200        pub struct #table_struct_name;
201
202        impl ::oxide_sql_core::schema::Table for #table_struct_name {
203            type Row = #struct_name;
204
205            const NAME: &'static str = #table_name;
206            const COLUMNS: &'static [&'static str] = &[#(#all_column_names),*];
207            #primary_key_impl
208        }
209
210        impl ::oxide_sql_core::schema::TableSchema
211            for #table_struct_name
212        {
213            const SCHEMA: &'static [
214                ::oxide_sql_core::schema::ColumnSchema
215            ] = &[
216                #(#schema_entries),*
217            ];
218        }
219
220        impl #table_struct_name {
221            /// Returns the table name.
222            #[inline]
223            pub const fn table_name() -> &'static str {
224                #table_name
225            }
226
227            #(#column_accessors)*
228        }
229
230        impl #struct_name {
231            /// Returns the table metadata type.
232            pub fn table() -> #table_struct_name {
233                #table_struct_name
234            }
235
236            #(#column_accessors)*
237        }
238    };
239
240    Ok(expanded)
241}
242
243struct ColumnInfo {
244    field_name: Ident,
245    field_type: Type,
246    column_name: String,
247    is_primary_key: bool,
248    is_nullable: bool,
249    is_unique: bool,
250    is_autoincrement: bool,
251    default_expr: Option<String>,
252}
253
/// Options parsed from a field's `#[column(...)]` attributes.
///
/// `parse_column_attrs` starts with every flag `false` and every optional value `None`,
/// then switches on or fills in the field matching each option it encounters.
struct ColumnAttrs {
    /// `name = "..."`: overrides the SQL column name.
    name: Option<String>,
    /// `default = "..."`: raw SQL default expression.
    default_expr: Option<String>,
    /// `primary_key` flag.
    primary_key: bool,
    /// `nullable` flag.
    nullable: bool,
    /// `unique` flag.
    unique: bool,
    /// `autoincrement` flag.
    autoincrement: bool,
}
262
263fn get_table_name(attrs: &[Attribute], struct_name: &Ident) -> syn::Result<String> {
264    for attr in attrs {
265        if attr.path().is_ident("table") {
266            let mut table_name = None;
267            attr.parse_nested_meta(|meta| {
268                if meta.path.is_ident("name") {
269                    let value: Expr = meta.value()?.parse()?;
270                    if let Expr::Lit(lit) = value {
271                        if let Lit::Str(s) = lit.lit {
272                            table_name = Some(s.value());
273                        }
274                    }
275                }
276                Ok(())
277            })?;
278            if let Some(name) = table_name {
279                return Ok(name);
280            }
281        }
282    }
283    // Default to snake_case of struct name
284    Ok(to_snake_case(&struct_name.to_string()))
285}
286
287fn parse_column_attrs(attrs: &[Attribute]) -> syn::Result<ColumnAttrs> {
288    let mut result = ColumnAttrs {
289        name: None,
290        primary_key: false,
291        nullable: false,
292        unique: false,
293        autoincrement: false,
294        default_expr: None,
295    };
296
297    for attr in attrs {
298        if attr.path().is_ident("column") {
299            // Handle empty attribute like #[column]
300            if matches!(attr.meta, Meta::Path(_)) {
301                continue;
302            }
303
304            attr.parse_nested_meta(|meta| {
305                if meta.path.is_ident("primary_key") {
306                    result.primary_key = true;
307                } else if meta.path.is_ident("nullable") {
308                    result.nullable = true;
309                } else if meta.path.is_ident("unique") {
310                    result.unique = true;
311                } else if meta.path.is_ident("autoincrement") {
312                    result.autoincrement = true;
313                } else if meta.path.is_ident("name") {
314                    let value: Expr = meta.value()?.parse()?;
315                    if let Expr::Lit(lit) = value {
316                        if let Lit::Str(s) = lit.lit {
317                            result.name = Some(s.value());
318                        }
319                    }
320                } else if meta.path.is_ident("default") {
321                    let value: Expr = meta.value()?.parse()?;
322                    if let Expr::Lit(lit) = value {
323                        if let Lit::Str(s) = lit.lit {
324                            result.default_expr = Some(s.value());
325                        }
326                    }
327                }
328                Ok(())
329            })?;
330        }
331    }
332
333    Ok(result)
334}
335
/// Converts a `PascalCase` identifier (such as a struct name) to `snake_case`.
///
/// Every uppercase character after the first starts a new word, so acronyms are split
/// letter by letter (`HTTPServer` -> `h_t_t_p_server`). A leading `r#` raw-identifier
/// prefix is ignored. Case mapping is Unicode-aware, matching the Unicode
/// `is_uppercase` test used to detect word boundaries.
fn to_snake_case(s: &str) -> String {
    let s = s.strip_prefix("r#").unwrap_or(s);
    let mut result = String::with_capacity(s.len());
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_ascii_lowercase` would leave non-ASCII capitals such as `Ä` unchanged even
            // though `is_uppercase` matched them.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
350
/// Converts a `snake_case` identifier (such as a field name) to `PascalCase`.
///
/// Underscores are dropped; the first character and each character following an
/// underscore are uppercased, and all other characters are kept as-is
/// (`user_id` -> `UserId`). A leading `r#` raw-identifier prefix is ignored, so
/// `r#type` yields `Type` rather than a string containing `#`. Case mapping is
/// Unicode-aware.
fn to_pascal_case(s: &str) -> String {
    let s = s.strip_prefix("r#").unwrap_or(s);
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;
    for c in s.chars() {
        if c == '_' {
            capitalize_next = true;
        } else if capitalize_next {
            // `to_ascii_uppercase` would leave non-ASCII letters such as `é` lowercase.
            result.extend(c.to_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }
    result
}