1 module dorm.api.condition;
2 
3 @safe:
4 
5 import std.conv;
6 import std.datetime;
7 import std.sumtype;
8 import std.traits;
9 import std.typecons : Nullable;
10 
11 import dorm.declarative;
12 
13 public import dorm.lib.ffi : FFIValue;
14 import ffi = dorm.lib.ffi;
15 
// Short names for the operator enums of the FFI unary/binary/ternary
// condition nodes; the operator members are enumerated exhaustively by the
// `final switch` statements in `dumpTree`.
alias UnaryConditionType = ffi.FFIUnaryCondition.Type;
alias BinaryConditionType = ffi.FFIBinaryCondition.Type;
alias TernaryConditionType = ffi.FFITernaryCondition.Type;
19 
/// A node of a query condition tree.
///
/// Holds exactly one of: a leaf `FFIValue`, a unary/binary/ternary operator
/// node, or an AND/OR list of sub-conditions. `alias impl this` forwards to
/// the underlying `SumType`, so e.g. `match` can be called on a `Condition`
/// directly (as `makeTree` does).
struct Condition
{
	/// The actual node; the active member determines the node kind.
	SumType!(
		FFIValue,
		UnaryCondition,
		BinaryCondition,
		TernaryCondition,
		AndCondition,
		OrCondition
	) impl;
	alias impl this;

	/// Constructs a condition from any value assignable to `impl`
	/// (i.e. one of the node types listed above).
	this(T)(T value)
	{
		impl = value;
	}

	/// Replaces the node with `value` (any type assignable to `impl`).
	/// Returns: a copy of this condition after the assignment (by value,
	/// not by reference).
	auto opAssign(T)(T value)
	{
		impl = value;
		return this;
	}

	/// Returns: an AND node over `conditions`.
	/// The variadic array is `.dup`ed because the array of a typesafe
	/// variadic parameter may be allocated on the caller's stack.
	static Condition and(Condition[] conditions...)
	{
		return Condition(AndCondition(conditions.dup));
	}

	/// Returns: an OR node over `conditions` (copied for the same reason
	/// as in `and`).
	static Condition or(Condition[] conditions...)
	{
		return Condition(OrCondition(conditions.dup));
	}

	/// Returns: a `Not` unary node whose operand is a heap-allocated copy of
	/// this condition; this condition itself is left unchanged.
	Condition not() const @trusted
	{
		// UnaryCondition refers to its operand by pointer, so the operand is
		// copied into GC memory that outlives this (possibly stack) object.
		// NOTE(review): @trusted is presumably needed to copy out of the
		// const `impl`; confirm.
		Condition* c = new Condition();
		c.impl = impl;
		return Condition(UnaryCondition(UnaryConditionType.Not, c));
	}
}
60 
/// Conjunction node: combines `conditions` with AND.
/// `makeTree` flattens it into an `FFICondition.Type.Conjunction` node.
struct AndCondition
{
	/// The combined sub-conditions (may be empty).
	Condition[] conditions;
}
65 
/// Disjunction node: combines `conditions` with OR.
/// `makeTree` flattens it into an `FFICondition.Type.Disjunction` node.
struct OrCondition
{
	/// The combined sub-conditions (may be empty).
	Condition[] conditions;
}
70 
/// Unary operator node (e.g. NOT, IS NULL) applied to one operand.
/// Field order matters: instances are constructed positionally.
struct UnaryCondition
{
	/// The operator.
	UnaryConditionType type;
	/// The operand. Held by pointer because `Condition` contains this struct,
	/// so it cannot contain a `Condition` by value.
	Condition* condition;
}
76 
/// Binary operator node `lhs <type> rhs` (e.g. =, <, LIKE, IN).
/// Field order matters: instances are constructed positionally.
struct BinaryCondition
{
	/// The operator.
	BinaryConditionType type;
	/// Left operand (by pointer, as `Condition` is recursive).
	Condition* lhs;
	/// Right operand (by pointer, as `Condition` is recursive).
	Condition* rhs;
}
83 
/// Ternary operator node, rendered by `dumpTree` as
/// `first [NOT] BETWEEN second AND third`.
/// Field order matters: instances are constructed positionally.
struct TernaryCondition
{
	/// The operator (Between / NotBetween).
	TernaryConditionType type;
	/// The tested value and the two bounds, by pointer as `Condition` is recursive.
	Condition* first, second, third;
}
89 
/**
 * Converts a D value into an `FFIValue` that can be used as a leaf of a
 * condition tree.
 *
 * Params:
 *   fieldInfo = compile-time description of the column the value is used
 *               with. Its `type` enables the integral -> datetime conversion,
 *               and `sourceColumn`/`definedAt` appear in the error message
 *               for unsupported types.
 *   c = the value to convert.
 *
 * Returns: an `FFIValue` whose `type` tag matches the populated member.
 *
 * Conversions:
 * $(UL
 *   $(LI `Nullable` -> Null or the converted payload.)
 *   $(LI Model references -> their foreign key.)
 *   $(LI `long`/`ulong` for a `datetime` column -> interpreted as SysTime
 *        stdTime (hnsecs) in UTC.)
 *   $(LI String-based enums -> their value; other enums -> their member name.)
 *   $(LI `SysTime` -> converted to UTC; sub-second precision is dropped.)
 * )
 *
 * Type dispatch uses `Unqual!T`, so `const`/`immutable`/`shared` arguments
 * take the same branch as their mutable counterparts.
 */
FFIValue conditionValue(ModelFormat.Field fieldInfo, T)(T c) @trusted
{
	import dorm.types.relations : ModelRefImpl;

	// `is(X == int)` is false for `const(int)`. Comparing the qualified type
	// would silently route qualified values into the wrong branch:
	// - I64 instead of I32,
	// - F64 for a const bool,
	// - the member name instead of the value for a const string enum.
	alias UT = Unqual!T;

	FFIValue ret;
	static if (is(UT == Nullable!U, U))
	{
		if (c.isNull)
			ret.type = FFIValue.Type.Null;
		else
			ret = conditionValue!fieldInfo(c.get);
	}
	else static if (is(UT == ModelRefImpl!(id, _TModel, _TSelect), alias id, _TModel, _TSelect))
	{
		ret = conditionValue!fieldInfo(c.foreignKey);
	}
	else static if (fieldInfo.type == ModelFormat.Field.DBType.datetime
		&& (is(UT == long) || is(UT == ulong)))
	{
		// integral datetime values are SysTime stdTime (hnsecs) in UTC
		ret = conditionValue!fieldInfo(cast(DateTime) SysTime(cast(long)c, UTC()));
	}
	else static if (is(UT == enum))
	{
		ret.type = FFIValue.Type.String;
		static if (is(OriginalType!UT == string))
			ret.str = ffi.ffi(cast(string)c);
		else
			ret.str = ffi.ffi(c.to!string); // std.conv : to gives us the enum field name
	}
	else static if (is(UT == typeof(null)))
	{
		ret.type = FFIValue.Type.Null;
	}
	else static if (is(UT == bool))
	{
		ret.type = FFIValue.Type.Bool;
		ret.boolean = c;
	}
	else static if (is(UT == short))
	{
		ret.type = FFIValue.Type.I16;
		ret.i16 = c;
	}
	else static if (is(UT == int))
	{
		ret.type = FFIValue.Type.I32;
		ret.i32 = c;
	}
	else static if (isIntegral!UT && is(UT : long))
	{
		ret.type = FFIValue.Type.I64;
		ret.i64 = c;
	}
	else static if (is(UT == float))
	{
		ret.type = FFIValue.Type.F32;
		ret.f32 = c;
	}
	else static if (is(T : double))
	{
		ret.type = FFIValue.Type.F64;
		ret.f64 = c;
	}
	else static if (is(T : const(char)[]))
	{
		ret.type = FFIValue.Type.String;
		ret.str = ffi.ffi(c);
	}
	else static if (is(T : Date))
	{
		ret.type = FFIValue.Type.NaiveDate;
		ret.naiveDate = ffi.FFIDate(cast(uint)c.day, cast(uint)c.month, cast(int)c.year);
	}
	else static if (is(T : TimeOfDay))
	{
		ret.type = FFIValue.Type.NaiveTime;
		ret.naiveTime = ffi.FFITime(cast(uint)c.hour, cast(uint)c.minute, cast(uint)c.second);
	}
	else static if (is(T : DateTime))
	{
		ret.type = FFIValue.Type.NaiveDateTime;
		ret.naiveDateTime = ffi.FFIDateTime(
			cast(int)c.year, cast(uint)c.month, cast(uint)c.day,
			cast(uint)c.hour, cast(uint)c.minute, cast(uint)c.second
		);
	}
	else static if (is(T : SysTime))
	{
		ret.type = FFIValue.Type.NaiveDateTime;
		auto d = cast(DateTime) c.toUTC;
		ret.naiveDateTime = ffi.FFIDateTime(
			cast(int)d.year, cast(uint)d.month, cast(uint)d.day,
			cast(uint)d.hour, cast(uint)d.minute, cast(uint)d.second
		);
	}
	else static if (is(UT == ffi.FFIString))
	{
		ret.type = FFIValue.Type.String;
		ret.str = c;
	}
	else
		static assert(false, text("Unsupported condition value type: ", T.stringof,
			"\n\tTried to build this value for column ", fieldInfo.sourceColumn, " in ", fieldInfo.definedAt).idup);
	return ret;
}
195 
/// Builds an `FFIValue` of type `Identifier` naming `identifier`.
/// The `return` attribute on the parameter presumably reflects that the
/// result may reference the string's memory via `ffi.ffi` — keep it alive.
FFIValue conditionIdentifier(return string identifier) @safe
{
	FFIValue value;
	value.type = FFIValue.Type.Identifier;

	// Only the FFI conversion/union store is @trusted; the rest stays @safe.
	void storeIdentifier() @trusted
	{
		value.identifier = ffi.ffi(identifier);
	}
	storeIdentifier();

	return value;
}
205 
/// Builds an `FFIValue` of type `Column` referring to `table`.`column`.
/// The `return` attributes presumably reflect that the result may reference
/// the strings' memory via `ffi.ffi` — keep them alive while it is used.
FFIValue columnValue(return string table, return string column) @safe
{
	FFIValue value;
	value.type = FFIValue.Type.Column;

	// Only the FFI conversion/union store is @trusted; the rest stays @safe.
	void storeColumn() @trusted
	{
		value.column = ffi.FFIColumn(ffi.ffi(table), ffi.ffi(column));
	}
	storeColumn();

	return value;
}
215 
/**
 * Flattens a `Condition` tree into a contiguous array of `ffi.FFICondition`
 * nodes whose child pointers point into that same array.
 *
 * Element 0 is always the root. The inner pointers refer to the returned
 * array's memory, so the array must be kept alive, and must not be resized,
 * for as long as those pointers are used (e.g. during the FFI call).
 *
 * Empty AND/OR lists are supported: their `content` is a valid
 * one-past-the-end pointer paired with `size == 0`.
 */
ffi.FFICondition[] makeTree(Condition c) @trusted
{
	// we store all conditions sequentially in a flat list, this function may be
	// run at CTFE, where it can then efficiently be put on the stack for
	// building the whole tree pointers. Otherwise everything is closer together
	// in memory as well, so we might even get performance improvements in the
	// runtime case.
	ffi.FFICondition[] ret;

	// as `ret` may be moved when resizing, we only store indices at first when
	// constructing the tree inside the pointer fields. Afterwards we go through
	// all generated items and replace the pointers, which hold list indices as
	// pointer values, instead of valid memory locations, and replace these
	// indices with the actual memory locations, to allow lookup on the other
	// side of the FFI boundary, which expects pointers as children.
	static void recurse(ref ffi.FFICondition[] ret, size_t dst, ref Condition c)
	{
		ffi.FFICondition dstret;
		c.match!(
			(FFIValue v)
			{
				dstret.type = ffi.FFICondition.Type.Value;
				dstret.value = v;
			},
			(UnaryCondition v)
			{
				size_t index = ret.length;
				ret.length++;
				recurse(ret, index, *v.condition);
				dstret.type = ffi.FFICondition.Type.UnaryCondition;
				dstret.unaryCondition = ffi.FFIUnaryCondition(v.type,
						cast(ffi.FFICondition*)(index));
			},
			(BinaryCondition v)
			{
				size_t index = ret.length;
				ret.length += 2;
				recurse(ret, index, *v.lhs);
				recurse(ret, index + 1, *v.rhs);
				dstret.type = ffi.FFICondition.Type.BinaryCondition;
				dstret.binaryCondition = ffi.FFIBinaryCondition(v.type,
						cast(ffi.FFICondition*)(index),
						cast(ffi.FFICondition*)(index + 1));
			},
			(TernaryCondition v)
			{
				size_t index = ret.length;
				ret.length += 3;
				recurse(ret, index, *v.first);
				recurse(ret, index + 1, *v.second);
				recurse(ret, index + 2, *v.third);
				dstret.type = ffi.FFICondition.Type.TernaryCondition;
				dstret.ternaryCondition = ffi.FFITernaryCondition(v.type,
						cast(ffi.FFICondition*)(index),
						cast(ffi.FFICondition*)(index + 1),
						cast(ffi.FFICondition*)(index + 2));
			},
			(AndCondition v)
			{
				size_t start = ret.length;
				ret.length += v.conditions.length;
				foreach (i, ref c; v.conditions)
					recurse(ret, start + i, c);
				dstret.type = ffi.FFICondition.Type.Conjunction;
				dstret.conjunction.content = cast(ffi.FFICondition*)start;
				dstret.conjunction.size = v.conditions.length;
			},
			(OrCondition v)
			{
				size_t start = ret.length;
				ret.length += v.conditions.length;
				foreach (i, ref c; v.conditions)
					recurse(ret, start + i, c);
				dstret.type = ffi.FFICondition.Type.Disjunction;
				dstret.disjunction.content = cast(ffi.FFICondition*)start;
				dstret.disjunction.size = v.conditions.length;
			}
		);
		ret[dst] = dstret;
	}

	ret.length = 1;
	recurse(ret, 0, c);

	// Resolves the start index of an AND/OR child list to an address.
	// An empty list reserves no slots, so its start index may equal
	// `ret.length` (one past the end). `&ret[ret.length]` would throw a
	// RangeError, whereas a one-past-the-end pointer paired with size 0 is a
	// valid (non-null) empty range. For non-empty lists this is the same
	// address as `&ret[index]`.
	ffi.FFICondition* listStart(ffi.FFICondition* encodedIndex)
	{
		const index = cast(size_t) encodedIndex;
		assert(index <= ret.length, "condition list start index out of range");
		return ret.ptr + index;
	}

	// now fix the pointer values:
	foreach (ref fixup; ret)
	{
		final switch (fixup.type)
		{
		case ffi.FFICondition.Type.Value: break;
		case ffi.FFICondition.Type.UnaryCondition:
			fixup.unaryCondition.condition = &ret[cast(size_t)fixup.unaryCondition.condition];
			break;
		case ffi.FFICondition.Type.BinaryCondition:
			fixup.binaryCondition.lhs = &ret[cast(size_t)fixup.binaryCondition.lhs];
			fixup.binaryCondition.rhs = &ret[cast(size_t)fixup.binaryCondition.rhs];
			break;
		case ffi.FFICondition.Type.TernaryCondition:
			fixup.ternaryCondition.first = &ret[cast(size_t)fixup.ternaryCondition.first];
			fixup.ternaryCondition.second = &ret[cast(size_t)fixup.ternaryCondition.second];
			fixup.ternaryCondition.third = &ret[cast(size_t)fixup.ternaryCondition.third];
			break;
		case ffi.FFICondition.Type.Conjunction:
			fixup.conjunction.content = listStart(fixup.conjunction.content);
			break;
		case ffi.FFICondition.Type.Disjunction:
			fixup.disjunction.content = listStart(fixup.disjunction.content);
			break;
		}
	}

	return ret;
}
329 
/**
 * Renders a flattened condition tree into a human-readable, SQL-like string
 * for debugging and tests.
 *
 * Params:
 *   c = nodes as produced by `makeTree`. `c[0]` is the root, and all child
 *       pointers must point to valid nodes.
 *
 * Returns: the rendering, prefixed with `"WHERE"`. Leaf values are shown as
 * `Value(...)` wrappers rather than SQL literals, so the result is not meant
 * to be executed.
 */
string dumpTree(ffi.FFICondition[] c)
{
	import std.array : appender;
	import std.format : format;

	auto query = appender!string;
	query ~= "WHERE";
	void recurse(ref ffi.FFICondition c) @trusted
	{
		import std.conv;

		final switch (c.type)
		{
		case ffi.FFICondition.Type.Value:
			query ~= " Value(";
			final switch (c.value.type)
			{
				case FFIValue.Type.String: query ~= '`' ~ c.value.str[] ~ '`'; break;
				case FFIValue.Type.Identifier: query ~= "ident:" ~ c.value.identifier[]; break;
				case FFIValue.Type.Column: query ~= c.value.column.to!string; break;
				case FFIValue.Type.Bool: query ~= c.value.boolean.to!string; break;
				case FFIValue.Type.I16: query ~= "i16:" ~ c.value.i16.to!string; break;
				case FFIValue.Type.I32: query ~= "i32:" ~ c.value.i32.to!string; break;
				case FFIValue.Type.I64: query ~= "i64:" ~ c.value.i64.to!string; break;
				case FFIValue.Type.F32: query ~= "f32:" ~ c.value.f32.to!string; break;
				case FFIValue.Type.F64: query ~= "f64:" ~ c.value.f64.to!string; break;
				case FFIValue.Type.Null: query ~= "[null]"; break;
				case FFIValue.Type.Binary: query ~= "(binary)"; break;
				case FFIValue.Type.NaiveTime:
					auto t = c.value.naiveTime;
					query ~= format!"%02d:%02d:%02d"(t.hour, t.min, t.sec);
					break;
				case FFIValue.Type.NaiveDate:
					auto d = c.value.naiveDate;
					query ~= format!"%04d-%02d-%02d"(d.year, d.month, d.day);
					break;
				case FFIValue.Type.NaiveDateTime:
					auto dt = c.value.naiveDateTime;
					query ~= format!"%04d-%02d-%02dT%02d:%02d:%02d"(dt.year, dt.month, dt.day, dt.hour, dt.min, dt.sec);
					break;
			}
			query ~= ")";
			break;
		case ffi.FFICondition.Type.UnaryCondition:
			final switch (c.unaryCondition.type)
			{
				case UnaryConditionType.Not:
					query ~= " NOT";
					recurse(*c.unaryCondition.condition);
					break;
				// EXISTS / NOT EXISTS are prefix operators in SQL
				// (`EXISTS (subquery)`), unlike the postfix IS [NOT] NULL.
				case UnaryConditionType.Exists:
					query ~= " EXISTS";
					recurse(*c.unaryCondition.condition);
					break;
				case UnaryConditionType.NotExists:
					query ~= " NOT EXISTS";
					recurse(*c.unaryCondition.condition);
					break;
				case UnaryConditionType.IsNull:
					recurse(*c.unaryCondition.condition);
					query ~= " IS NULL";
					break;
				case UnaryConditionType.IsNotNull:
					recurse(*c.unaryCondition.condition);
					query ~= " IS NOT NULL";
					break;
			}
			break;
		case ffi.FFICondition.Type.BinaryCondition:
			recurse(*c.binaryCondition.lhs);
			final switch (c.binaryCondition.type)
			{
				case BinaryConditionType.Equals:
					query ~= " =";
					break;
				case BinaryConditionType.NotEquals:
					query ~= " !=";
					break;
				case BinaryConditionType.Greater:
					query ~= " >";
					break;
				case BinaryConditionType.GreaterOrEquals:
					query ~= " >=";
					break;
				case BinaryConditionType.Less:
					query ~= " <";
					break;
				case BinaryConditionType.LessOrEquals:
					query ~= " <=";
					break;
				case BinaryConditionType.Like:
					query ~= " LIKE";
					break;
				case BinaryConditionType.NotLike:
					query ~= " NOT LIKE";
					break;
				case BinaryConditionType.In:
					query ~= " IN";
					break;
				case BinaryConditionType.NotIn:
					query ~= " NOT IN";
					break;
				case BinaryConditionType.Regexp:
					query ~= " REGEXP";
					break;
				case BinaryConditionType.NotRegexp:
					query ~= " NOT REGEXP";
					break;
			}
			recurse(*c.binaryCondition.rhs);
			break;
		case ffi.FFICondition.Type.TernaryCondition:
			recurse(*c.ternaryCondition.first);
			final switch (c.ternaryCondition.type)
			{
				case TernaryConditionType.Between:
					query ~= " BETWEEN";
					break;
				case TernaryConditionType.NotBetween:
					query ~= " NOT BETWEEN";
					break;
			}
			recurse(*c.ternaryCondition.second);
			query ~= " AND";
			recurse(*c.ternaryCondition.third);
			break;
		case ffi.FFICondition.Type.Conjunction:
		case ffi.FFICondition.Type.Disjunction:
			// Read the union member that matches the node's type tag instead
			// of always going through `conjunction`.
			const bool isConjunction = c.type == ffi.FFICondition.Type.Conjunction;
			string op = isConjunction ? " AND" : " OR";
			auto children = isConjunction ? c.conjunction.data : c.disjunction.data;
			query ~= " (";
			foreach (i, ref subc; children)
			{
				if (i != 0) query ~= op;
				recurse(subc);
			}
			query ~= " )";
			break;
		}
	}
	recurse(c[0]);
	return query.data;
}
472 
unittest
{
	import std.array;

	// Tiny builders that heap-allocate Condition nodes, so the tree below
	// reads close to the SQL it is expected to render as.
	Condition* allOf(scope Condition[] parts...)
	{
		return new Condition(AndCondition(parts.dup));
	}

	Condition* anyOf(scope Condition[] parts...)
	{
		return new Condition(OrCondition(parts.dup));
	}

	Condition* unary(UnaryConditionType op, Condition* operand)
	{
		return new Condition(UnaryCondition(op, operand));
	}

	Condition* binary(Condition* left, BinaryConditionType op, Condition* right)
	{
		return new Condition(BinaryCondition(op, left, right));
	}

	Condition* ternary(Condition* subject, TernaryConditionType op, Condition* low, Condition* high)
	{
		return new Condition(TernaryCondition(op, subject, low, high));
	}

	Condition* ident(string name)
	{
		return new Condition(conditionIdentifier(name));
	}

	Condition* val(T)(T value)
	{
		return new Condition(conditionValue!(ModelFormat.Field.init)(value));
	}

	auto root = allOf(
		*binary(ident("foo"), BinaryConditionType.Equals, val("wert")),
		*binary(ident("bar"), BinaryConditionType.Greater, val(5)),
		*unary(UnaryConditionType.Not, anyOf(
			*binary(ident("baz"), BinaryConditionType.Equals, val(1)),
			*binary(ident("baz"), BinaryConditionType.Equals, val(4)),
		))
	);

	const expected = "WHERE ( Value(ident:foo) = Value(`wert`) AND Value(ident:bar) > Value(i32:5) AND "
		~ "NOT ( Value(ident:baz) = Value(i32:1) OR Value(ident:baz) = Value(i32:4) ) )";
	assert(dumpTree(makeTree(*root)) == expected);
}