1use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::{format_ident, quote};
9use syn::{Attribute, Data, DeriveInput, Expr, Fields, Ident, Lit, Meta, Type, parse_macro_input};
10
11#[proc_macro_derive(Table, attributes(table, column))]
36pub fn derive_table(input: TokenStream) -> TokenStream {
37 let input = parse_macro_input!(input as DeriveInput);
38 derive_table_impl(input)
39 .unwrap_or_else(|e| e.to_compile_error())
40 .into()
41}
42
43fn derive_table_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
44 let struct_name = &input.ident;
45 let table_name = get_table_name(&input.attrs, struct_name)?;
46
47 let fields = match &input.data {
48 Data::Struct(data) => match &data.fields {
49 Fields::Named(fields) => &fields.named,
50 _ => {
51 return Err(syn::Error::new_spanned(
52 &input,
53 "Table derive only supports structs with named fields",
54 ));
55 }
56 },
57 _ => {
58 return Err(syn::Error::new_spanned(
59 &input,
60 "Table derive only supports structs",
61 ));
62 }
63 };
64
65 let mut column_infos: Vec<ColumnInfo> = Vec::new();
67 for field in fields {
68 let field_name = field.ident.as_ref().unwrap();
69 let field_type = &field.ty;
70 let column_attrs = parse_column_attrs(&field.attrs)?;
71
72 column_infos.push(ColumnInfo {
73 field_name: field_name.clone(),
74 field_type: field_type.clone(),
75 column_name: column_attrs.name.unwrap_or_else(|| field_name.to_string()),
76 is_primary_key: column_attrs.primary_key,
77 is_nullable: column_attrs.nullable,
78 is_unique: column_attrs.unique,
79 is_autoincrement: column_attrs.autoincrement,
80 default_expr: column_attrs.default_expr,
81 });
82 }
83
84 let column_type_names: Vec<Ident> = column_infos
86 .iter()
87 .map(|c| format_ident!("{}", to_pascal_case(&c.field_name.to_string())))
88 .collect();
89
90 let table_struct_name = format_ident!("{}Table", struct_name);
92 let columns_mod_name = format_ident!("{}Columns", struct_name);
93
94 let column_structs: Vec<TokenStream2> = column_infos
96 .iter()
97 .zip(column_type_names.iter())
98 .map(|(info, type_name)| {
99 let column_name = &info.column_name;
100 let field_type = &info.field_type;
101 let is_nullable = info.is_nullable;
102 let is_primary_key = info.is_primary_key;
103
104 quote! {
105 #[derive(Debug, Clone, Copy)]
107 pub struct #type_name;
108
109 impl ::oxide_sql_core::schema::Column for #type_name {
110 type Table = super::#table_struct_name;
111 type Type = #field_type;
112
113 const NAME: &'static str = #column_name;
114 const NULLABLE: bool = #is_nullable;
115 const PRIMARY_KEY: bool = #is_primary_key;
116 }
117
118 impl ::oxide_sql_core::schema::TypedColumn<#field_type> for #type_name {}
119 }
120 })
121 .collect();
122
123 let column_accessors: Vec<TokenStream2> = column_infos
125 .iter()
126 .zip(column_type_names.iter())
127 .map(|(info, type_name)| {
128 let method_name = &info.field_name;
129 quote! {
130 #[inline]
132 pub const fn #method_name() -> #columns_mod_name::#type_name {
133 #columns_mod_name::#type_name
134 }
135 }
136 })
137 .collect();
138
139 let all_column_names: Vec<&str> = column_infos
141 .iter()
142 .map(|c| c.column_name.as_str())
143 .collect();
144
145 let primary_key_column = column_infos
147 .iter()
148 .find(|c| c.is_primary_key)
149 .map(|c| &c.column_name);
150
151 let primary_key_impl = if let Some(pk) = primary_key_column {
152 quote! {
153 const PRIMARY_KEY: Option<&'static str> = Some(#pk);
154 }
155 } else {
156 quote! {
157 const PRIMARY_KEY: Option<&'static str> = None;
158 }
159 };
160
161 let schema_entries: Vec<TokenStream2> = column_infos
163 .iter()
164 .map(|info| {
165 let col_name = &info.column_name;
166 let field_type = &info.field_type;
167 let rust_type_str = quote!(#field_type).to_string().replace(' ', "");
168 let is_nullable = info.is_nullable;
169 let is_primary_key = info.is_primary_key;
170 let is_unique = info.is_unique;
171 let is_autoincrement = info.is_autoincrement;
172 let default_expr_token = match &info.default_expr {
173 Some(expr) => quote! { Some(#expr) },
174 None => quote! { None },
175 };
176
177 quote! {
178 ::oxide_sql_core::schema::ColumnSchema {
179 name: #col_name,
180 rust_type: #rust_type_str,
181 nullable: #is_nullable,
182 primary_key: #is_primary_key,
183 unique: #is_unique,
184 autoincrement: #is_autoincrement,
185 default_expr: #default_expr_token,
186 }
187 }
188 })
189 .collect();
190
191 let expanded = quote! {
192 #[allow(non_snake_case)]
194 pub mod #columns_mod_name {
195 #(#column_structs)*
196 }
197
198 #[derive(Debug, Clone, Copy)]
200 pub struct #table_struct_name;
201
202 impl ::oxide_sql_core::schema::Table for #table_struct_name {
203 type Row = #struct_name;
204
205 const NAME: &'static str = #table_name;
206 const COLUMNS: &'static [&'static str] = &[#(#all_column_names),*];
207 #primary_key_impl
208 }
209
210 impl ::oxide_sql_core::schema::TableSchema
211 for #table_struct_name
212 {
213 const SCHEMA: &'static [
214 ::oxide_sql_core::schema::ColumnSchema
215 ] = &[
216 #(#schema_entries),*
217 ];
218 }
219
220 impl #table_struct_name {
221 #[inline]
223 pub const fn table_name() -> &'static str {
224 #table_name
225 }
226
227 #(#column_accessors)*
228 }
229
230 impl #struct_name {
231 pub fn table() -> #table_struct_name {
233 #table_struct_name
234 }
235
236 #(#column_accessors)*
237 }
238 };
239
240 Ok(expanded)
241}
242
243struct ColumnInfo {
244 field_name: Ident,
245 field_type: Type,
246 column_name: String,
247 is_primary_key: bool,
248 is_nullable: bool,
249 is_unique: bool,
250 is_autoincrement: bool,
251 default_expr: Option<String>,
252}
253
/// Options parsed from the `#[column(...)]` attributes on a single field.
struct ColumnAttrs {
    /// `name = "..."`: overrides the column name.
    name: Option<String>,
    /// `default = "..."`: text forwarded as the column's default expression.
    default_expr: Option<String>,
    /// `primary_key` flag.
    primary_key: bool,
    /// `nullable` flag.
    nullable: bool,
    /// `unique` flag.
    unique: bool,
    /// `autoincrement` flag.
    autoincrement: bool,
}
262
263fn get_table_name(attrs: &[Attribute], struct_name: &Ident) -> syn::Result<String> {
264 for attr in attrs {
265 if attr.path().is_ident("table") {
266 let mut table_name = None;
267 attr.parse_nested_meta(|meta| {
268 if meta.path.is_ident("name") {
269 let value: Expr = meta.value()?.parse()?;
270 if let Expr::Lit(lit) = value {
271 if let Lit::Str(s) = lit.lit {
272 table_name = Some(s.value());
273 }
274 }
275 }
276 Ok(())
277 })?;
278 if let Some(name) = table_name {
279 return Ok(name);
280 }
281 }
282 }
283 Ok(to_snake_case(&struct_name.to_string()))
285}
286
287fn parse_column_attrs(attrs: &[Attribute]) -> syn::Result<ColumnAttrs> {
288 let mut result = ColumnAttrs {
289 name: None,
290 primary_key: false,
291 nullable: false,
292 unique: false,
293 autoincrement: false,
294 default_expr: None,
295 };
296
297 for attr in attrs {
298 if attr.path().is_ident("column") {
299 if matches!(attr.meta, Meta::Path(_)) {
301 continue;
302 }
303
304 attr.parse_nested_meta(|meta| {
305 if meta.path.is_ident("primary_key") {
306 result.primary_key = true;
307 } else if meta.path.is_ident("nullable") {
308 result.nullable = true;
309 } else if meta.path.is_ident("unique") {
310 result.unique = true;
311 } else if meta.path.is_ident("autoincrement") {
312 result.autoincrement = true;
313 } else if meta.path.is_ident("name") {
314 let value: Expr = meta.value()?.parse()?;
315 if let Expr::Lit(lit) = value {
316 if let Lit::Str(s) = lit.lit {
317 result.name = Some(s.value());
318 }
319 }
320 } else if meta.path.is_ident("default") {
321 let value: Expr = meta.value()?.parse()?;
322 if let Expr::Lit(lit) = value {
323 if let Lit::Str(s) = lit.lit {
324 result.default_expr = Some(s.value());
325 }
326 }
327 }
328 Ok(())
329 })?;
330 }
331 }
332
333 Ok(result)
334}
335
/// Converts a PascalCase name to snake_case (`UserProfile` -> `user_profile`).
///
/// Every uppercase character after the first starts a new word, so acronyms
/// are split letter by letter (`HTTPLog` -> `h_t_t_p_log`); use
/// `#[table(name = "...")]` to override. Uppercase detection and lowercasing
/// are both Unicode-aware, so non-ASCII capitals are lowercased too
/// (`CaféÉtoile` -> `café_étoile`).
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` may yield several chars; `to_ascii_lowercase`
            // would leave non-ASCII capitals untouched.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
350
/// Converts a snake_case name to PascalCase (`user_id` -> `UserId`).
///
/// Underscores are dropped and the character following each one (and the
/// first character) is uppercased. Uppercasing is Unicode-aware, so non-ASCII
/// initials are capitalized too (`élan_vital` -> `ÉlanVital`); otherwise the
/// generated column struct name would trip `non_camel_case_types`. Leading or
/// repeated underscores vanish (`__x` -> `X`), and an all-underscore input
/// yields an empty string.
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;
    for c in s.chars() {
        if c == '_' {
            capitalize_next = true;
        } else if capitalize_next {
            // `to_uppercase` may yield several chars (e.g. `ß` -> `SS`).
            result.extend(c.to_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }
    result
}