oxide_sql_sqlite/builder/upsert.rs

1//! SQLite UPSERT (INSERT ... ON CONFLICT) builder.
2
3use std::marker::PhantomData;
4
5use oxide_sql_core::builder::value::{SqlValue, ToSqlValue};
6
// Typestate markers
//
// Zero-sized types used as `UpsertBuilder` type parameters to record, at
// compile time, which parts of the statement have already been supplied.

/// Marker: No table specified yet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoTable;
/// Marker: Table has been specified.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct HasTable;
/// Marker: No values specified yet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoValues;
/// Marker: Values have been specified.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct HasValues;
/// Marker: No conflict target specified yet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoConflict;
/// Marker: Conflict target has been specified.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct HasConflict;
21
22/// A type-safe UPSERT (INSERT ... ON CONFLICT) builder for SQLite.
23pub struct UpsertBuilder<Table, Values, Conflict> {
24    table: Option<String>,
25    columns: Vec<String>,
26    values: Vec<SqlValue>,
27    conflict_columns: Vec<String>,
28    update_columns: Vec<String>,
29    do_nothing: bool,
30    _state: PhantomData<(Table, Values, Conflict)>,
31}
32
33impl UpsertBuilder<NoTable, NoValues, NoConflict> {
34    /// Creates a new UPSERT builder.
35    #[must_use]
36    pub fn new() -> Self {
37        Self {
38            table: None,
39            columns: vec![],
40            values: vec![],
41            conflict_columns: vec![],
42            update_columns: vec![],
43            do_nothing: false,
44            _state: PhantomData,
45        }
46    }
47}
48
49impl Default for UpsertBuilder<NoTable, NoValues, NoConflict> {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55// Transition: NoTable -> HasTable
56impl<Values, Conflict> UpsertBuilder<NoTable, Values, Conflict> {
57    /// Specifies the table to insert into.
58    #[must_use]
59    pub fn into_table(self, table: &str) -> UpsertBuilder<HasTable, Values, Conflict> {
60        UpsertBuilder {
61            table: Some(String::from(table)),
62            columns: self.columns,
63            values: self.values,
64            conflict_columns: self.conflict_columns,
65            update_columns: self.update_columns,
66            do_nothing: self.do_nothing,
67            _state: PhantomData,
68        }
69    }
70}
71
72// Methods available after specifying table
73impl<Values, Conflict> UpsertBuilder<HasTable, Values, Conflict> {
74    /// Specifies the columns to insert into.
75    #[must_use]
76    pub fn columns(mut self, cols: &[&str]) -> Self {
77        self.columns = cols.iter().map(|s| String::from(*s)).collect();
78        self
79    }
80}
81
82// Transition: NoValues -> HasValues
83impl<Conflict> UpsertBuilder<HasTable, NoValues, Conflict> {
84    /// Adds values to insert.
85    #[must_use]
86    pub fn values<T: ToSqlValue>(
87        self,
88        vals: Vec<T>,
89    ) -> UpsertBuilder<HasTable, HasValues, Conflict> {
90        let sql_values: Vec<SqlValue> = vals.into_iter().map(ToSqlValue::to_sql_value).collect();
91        UpsertBuilder {
92            table: self.table,
93            columns: self.columns,
94            values: sql_values,
95            conflict_columns: self.conflict_columns,
96            update_columns: self.update_columns,
97            do_nothing: self.do_nothing,
98            _state: PhantomData,
99        }
100    }
101}
102
103// Transition: NoConflict -> HasConflict
104impl UpsertBuilder<HasTable, HasValues, NoConflict> {
105    /// Specifies the conflict target columns.
106    #[must_use]
107    pub fn on_conflict(self, cols: &[&str]) -> UpsertBuilder<HasTable, HasValues, HasConflict> {
108        UpsertBuilder {
109            table: self.table,
110            columns: self.columns,
111            values: self.values,
112            conflict_columns: cols.iter().map(|s| String::from(*s)).collect(),
113            update_columns: self.update_columns,
114            do_nothing: self.do_nothing,
115            _state: PhantomData,
116        }
117    }
118}
119
120// Methods available after ON CONFLICT
121impl UpsertBuilder<HasTable, HasValues, HasConflict> {
122    /// Sets DO NOTHING action.
123    #[must_use]
124    pub fn do_nothing(mut self) -> Self {
125        self.do_nothing = true;
126        self
127    }
128
129    /// Sets DO UPDATE with specified columns.
130    #[must_use]
131    pub fn do_update(mut self, cols: &[&str]) -> Self {
132        self.update_columns = cols.iter().map(|s| String::from(*s)).collect();
133        self.do_nothing = false;
134        self
135    }
136
137    /// Builds the UPSERT statement and returns SQL with parameters.
138    #[must_use]
139    pub fn build(self) -> (String, Vec<SqlValue>) {
140        let mut sql = String::from("INSERT INTO ");
141        let mut params = vec![];
142
143        if let Some(ref table) = self.table {
144            sql.push_str(table);
145        }
146
147        if !self.columns.is_empty() {
148            sql.push_str(" (");
149            sql.push_str(&self.columns.join(", "));
150            sql.push(')');
151        }
152
153        sql.push_str(" VALUES (");
154        let placeholders: Vec<&str> = self.values.iter().map(|_| "?").collect();
155        sql.push_str(&placeholders.join(", "));
156        sql.push(')');
157
158        params.extend(self.values.clone());
159
160        sql.push_str(" ON CONFLICT (");
161        sql.push_str(&self.conflict_columns.join(", "));
162        sql.push(')');
163
164        if self.do_nothing {
165            sql.push_str(" DO NOTHING");
166        } else if !self.update_columns.is_empty() {
167            sql.push_str(" DO UPDATE SET ");
168            let updates: Vec<String> = self
169                .update_columns
170                .iter()
171                .map(|col| format!("{col} = excluded.{col}"))
172                .collect();
173            sql.push_str(&updates.join(", "));
174        }
175
176        (sql, params)
177    }
178
179    /// Builds the UPSERT statement and returns only the SQL string.
180    #[must_use]
181    pub fn build_sql(self) -> String {
182        let (sql, _) = self.build();
183        sql
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_upsert_do_nothing() {
        let (sql, params) = UpsertBuilder::new()
            .into_table("users")
            .columns(&["id", "name"])
            .values(vec![1_i64.to_sql_value(), "Alice".to_sql_value()])
            .on_conflict(&["id"])
            .do_nothing()
            .build();

        assert_eq!(
            sql,
            "INSERT INTO users (id, name) VALUES (?, ?) ON CONFLICT (id) DO NOTHING"
        );
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_upsert_do_update() {
        let (sql, params) = UpsertBuilder::new()
            .into_table("users")
            .columns(&["id", "name", "email"])
            .values(vec![
                1_i64.to_sql_value(),
                "Alice".to_sql_value(),
                "alice@example.com".to_sql_value(),
            ])
            .on_conflict(&["id"])
            .do_update(&["name", "email"])
            .build();

        assert_eq!(
            sql,
            "INSERT INTO users (id, name, email) VALUES (?, ?, ?) \
             ON CONFLICT (id) DO UPDATE SET name = excluded.name, email = excluded.email"
        );
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_upsert_composite_key() {
        let (sql, _) = UpsertBuilder::new()
            .into_table("user_roles")
            .columns(&["user_id", "role_id", "granted_at"])
            .values(vec![
                1_i64.to_sql_value(),
                2_i64.to_sql_value(),
                "2024-01-01".to_sql_value(),
            ])
            .on_conflict(&["user_id", "role_id"])
            .do_update(&["granted_at"])
            .build();

        assert!(sql.contains("ON CONFLICT (user_id, role_id)"));
        assert!(sql.contains("DO UPDATE SET granted_at = excluded.granted_at"));
    }

    #[test]
    fn test_upsert_sql_injection_prevention() {
        let malicious = "'; DROP TABLE users; --";
        let (sql, params) = UpsertBuilder::new()
            .into_table("users")
            .columns(&["id", "name"])
            .values(vec![1_i64.to_sql_value(), malicious.to_sql_value()])
            .on_conflict(&["id"])
            .do_update(&["name"])
            .build();

        // SQL uses parameterized placeholders
        assert!(sql.contains("VALUES (?, ?)"));
        // Malicious input is safely stored as parameter
        assert!(matches!(&params[1], SqlValue::Text(s) if s == malicious));
    }

    #[test]
    fn test_upsert_without_column_list() {
        let (sql, params) = UpsertBuilder::new()
            .into_table("users")
            .values(vec![1_i64.to_sql_value()])
            .on_conflict(&["id"])
            .do_nothing()
            .build();

        // No `columns` call means no parenthesized column list.
        assert_eq!(sql, "INSERT INTO users VALUES (?) ON CONFLICT (id) DO NOTHING");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_upsert_last_conflict_action_wins() {
        let update_after_nothing = UpsertBuilder::new()
            .into_table("users")
            .columns(&["id", "name"])
            .values(vec![1_i64.to_sql_value(), "Alice".to_sql_value()])
            .on_conflict(&["id"])
            .do_nothing()
            .do_update(&["name"])
            .build_sql();
        assert!(update_after_nothing.ends_with("ON CONFLICT (id) DO UPDATE SET name = excluded.name"));

        let nothing_after_update = UpsertBuilder::new()
            .into_table("users")
            .columns(&["id", "name"])
            .values(vec![1_i64.to_sql_value(), "Alice".to_sql_value()])
            .on_conflict(&["id"])
            .do_update(&["name"])
            .do_nothing()
            .build_sql();
        assert!(nothing_after_update.ends_with("ON CONFLICT (id) DO NOTHING"));
    }

    #[test]
    fn test_default_and_build_sql_match_new_and_build() {
        let via_default: UpsertBuilder<NoTable, NoValues, NoConflict> = UpsertBuilder::default();
        let sql_only = via_default
            .into_table("users")
            .columns(&["id", "name"])
            .values(vec![1_i64.to_sql_value(), "Alice".to_sql_value()])
            .on_conflict(&["id"])
            .do_update(&["name"])
            .build_sql();

        let (sql, _) = UpsertBuilder::new()
            .into_table("users")
            .columns(&["id", "name"])
            .values(vec![1_i64.to_sql_value(), "Alice".to_sql_value()])
            .on_conflict(&["id"])
            .do_update(&["name"])
            .build();

        assert_eq!(sql_only, sql);
    }
}