oxide_sql_sqlite/builder/
upsert.rs1use std::marker::PhantomData;
4
5use oxide_sql_core::builder::value::{SqlValue, ToSqlValue};
6
/// Type-state marker: no target table has been chosen yet.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoTable;
/// Type-state marker: the target table has been chosen via `into_table`.
#[derive(Debug, Clone, Copy, Default)]
pub struct HasTable;
/// Type-state marker: no row values have been supplied yet.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoValues;
/// Type-state marker: row values have been supplied via `values`.
#[derive(Debug, Clone, Copy, Default)]
pub struct HasValues;
/// Type-state marker: no conflict target has been declared yet.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoConflict;
/// Type-state marker: a conflict target has been declared via `on_conflict`.
#[derive(Debug, Clone, Copy, Default)]
pub struct HasConflict;
21
22pub struct UpsertBuilder<Table, Values, Conflict> {
24 table: Option<String>,
25 columns: Vec<String>,
26 values: Vec<SqlValue>,
27 conflict_columns: Vec<String>,
28 update_columns: Vec<String>,
29 do_nothing: bool,
30 _state: PhantomData<(Table, Values, Conflict)>,
31}
32
33impl UpsertBuilder<NoTable, NoValues, NoConflict> {
34 #[must_use]
36 pub fn new() -> Self {
37 Self {
38 table: None,
39 columns: vec![],
40 values: vec![],
41 conflict_columns: vec![],
42 update_columns: vec![],
43 do_nothing: false,
44 _state: PhantomData,
45 }
46 }
47}
48
49impl Default for UpsertBuilder<NoTable, NoValues, NoConflict> {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl<Values, Conflict> UpsertBuilder<NoTable, Values, Conflict> {
57 #[must_use]
59 pub fn into_table(self, table: &str) -> UpsertBuilder<HasTable, Values, Conflict> {
60 UpsertBuilder {
61 table: Some(String::from(table)),
62 columns: self.columns,
63 values: self.values,
64 conflict_columns: self.conflict_columns,
65 update_columns: self.update_columns,
66 do_nothing: self.do_nothing,
67 _state: PhantomData,
68 }
69 }
70}
71
72impl<Values, Conflict> UpsertBuilder<HasTable, Values, Conflict> {
74 #[must_use]
76 pub fn columns(mut self, cols: &[&str]) -> Self {
77 self.columns = cols.iter().map(|s| String::from(*s)).collect();
78 self
79 }
80}
81
82impl<Conflict> UpsertBuilder<HasTable, NoValues, Conflict> {
84 #[must_use]
86 pub fn values<T: ToSqlValue>(
87 self,
88 vals: Vec<T>,
89 ) -> UpsertBuilder<HasTable, HasValues, Conflict> {
90 let sql_values: Vec<SqlValue> = vals.into_iter().map(ToSqlValue::to_sql_value).collect();
91 UpsertBuilder {
92 table: self.table,
93 columns: self.columns,
94 values: sql_values,
95 conflict_columns: self.conflict_columns,
96 update_columns: self.update_columns,
97 do_nothing: self.do_nothing,
98 _state: PhantomData,
99 }
100 }
101}
102
103impl UpsertBuilder<HasTable, HasValues, NoConflict> {
105 #[must_use]
107 pub fn on_conflict(self, cols: &[&str]) -> UpsertBuilder<HasTable, HasValues, HasConflict> {
108 UpsertBuilder {
109 table: self.table,
110 columns: self.columns,
111 values: self.values,
112 conflict_columns: cols.iter().map(|s| String::from(*s)).collect(),
113 update_columns: self.update_columns,
114 do_nothing: self.do_nothing,
115 _state: PhantomData,
116 }
117 }
118}
119
120impl UpsertBuilder<HasTable, HasValues, HasConflict> {
122 #[must_use]
124 pub fn do_nothing(mut self) -> Self {
125 self.do_nothing = true;
126 self
127 }
128
129 #[must_use]
131 pub fn do_update(mut self, cols: &[&str]) -> Self {
132 self.update_columns = cols.iter().map(|s| String::from(*s)).collect();
133 self.do_nothing = false;
134 self
135 }
136
137 #[must_use]
139 pub fn build(self) -> (String, Vec<SqlValue>) {
140 let mut sql = String::from("INSERT INTO ");
141 let mut params = vec![];
142
143 if let Some(ref table) = self.table {
144 sql.push_str(table);
145 }
146
147 if !self.columns.is_empty() {
148 sql.push_str(" (");
149 sql.push_str(&self.columns.join(", "));
150 sql.push(')');
151 }
152
153 sql.push_str(" VALUES (");
154 let placeholders: Vec<&str> = self.values.iter().map(|_| "?").collect();
155 sql.push_str(&placeholders.join(", "));
156 sql.push(')');
157
158 params.extend(self.values.clone());
159
160 sql.push_str(" ON CONFLICT (");
161 sql.push_str(&self.conflict_columns.join(", "));
162 sql.push(')');
163
164 if self.do_nothing {
165 sql.push_str(" DO NOTHING");
166 } else if !self.update_columns.is_empty() {
167 sql.push_str(" DO UPDATE SET ");
168 let updates: Vec<String> = self
169 .update_columns
170 .iter()
171 .map(|col| format!("{col} = excluded.{col}"))
172 .collect();
173 sql.push_str(&updates.join(", "));
174 }
175
176 (sql, params)
177 }
178
179 #[must_use]
181 pub fn build_sql(self) -> String {
182 let (sql, _) = self.build();
183 sql
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_upsert_do_nothing() {
        let (sql, params) = UpsertBuilder::new()
            .into_table("users")
            .columns(&["id", "name"])
            .values(vec![1_i64.to_sql_value(), "Alice".to_sql_value()])
            .on_conflict(&["id"])
            .do_nothing()
            .build();

        assert_eq!(
            sql,
            "INSERT INTO users (id, name) VALUES (?, ?) ON CONFLICT (id) DO NOTHING"
        );
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_upsert_do_update() {
        let (sql, params) = UpsertBuilder::new()
            .into_table("users")
            .columns(&["id", "name", "email"])
            .values(vec![
                1_i64.to_sql_value(),
                "Alice".to_sql_value(),
                "alice@example.com".to_sql_value(),
            ])
            .on_conflict(&["id"])
            .do_update(&["name", "email"])
            .build();

        assert_eq!(
            sql,
            "INSERT INTO users (id, name, email) VALUES (?, ?, ?) \
             ON CONFLICT (id) DO UPDATE SET name = excluded.name, email = excluded.email"
        );
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_upsert_composite_key() {
        let (sql, _) = UpsertBuilder::new()
            .into_table("user_roles")
            .columns(&["user_id", "role_id", "granted_at"])
            .values(vec![
                1_i64.to_sql_value(),
                2_i64.to_sql_value(),
                "2024-01-01".to_sql_value(),
            ])
            .on_conflict(&["user_id", "role_id"])
            .do_update(&["granted_at"])
            .build();

        assert!(sql.contains("ON CONFLICT (user_id, role_id)"));
        assert!(sql.contains("DO UPDATE SET granted_at = excluded.granted_at"));
    }

    #[test]
    fn test_upsert_sql_injection_prevention() {
        let malicious = "'; DROP TABLE users; --";
        let (sql, params) = UpsertBuilder::new()
            .into_table("users")
            .columns(&["id", "name"])
            .values(vec![1_i64.to_sql_value(), malicious.to_sql_value()])
            .on_conflict(&["id"])
            .do_update(&["name"])
            .build();

        // The hostile string must appear only as a bound parameter...
        assert!(sql.contains("VALUES (?, ?)"));
        assert!(!sql.contains(malicious));
        // ...and be passed through to the parameter list unchanged.
        assert!(matches!(&params[1], SqlValue::Text(s) if s == malicious));
    }

    #[test]
    fn test_upsert_without_column_list() {
        let (sql, params) = UpsertBuilder::default()
            .into_table("t")
            .values(vec![1_i64.to_sql_value()])
            .on_conflict(&["a"])
            .do_nothing()
            .build();

        assert_eq!(sql, "INSERT INTO t VALUES (?) ON CONFLICT (a) DO NOTHING");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_build_sql_matches_build() {
        let make = || {
            UpsertBuilder::new()
                .into_table("t")
                .columns(&["a", "b"])
                .values(vec![1_i64.to_sql_value(), "x".to_sql_value()])
                .on_conflict(&["a"])
                .do_update(&["b"])
        };

        let (sql, _) = make().build();
        assert_eq!(make().build_sql(), sql);
    }
}