Details | Last modification | View Log | RSS feed
Rev | Author | Line No. | Line |
---|---|---|---|
14 | pmbaty | 1 | //===- WrapperFunctionUtils.h - Utilities for wrapper functions -*- C++ -*-===// |
2 | // |
||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
||
4 | // See https://llvm.org/LICENSE.txt for license information. |
||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
||
6 | // |
||
7 | //===----------------------------------------------------------------------===// |
||
8 | // |
||
9 | // A buffer for serialized results. |
||
10 | // |
||
11 | //===----------------------------------------------------------------------===// |
||
12 | |||
13 | #ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H |
||
14 | #define LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H |
||
15 | |||
#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h"
#include "llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h"
#include "llvm/Support/Error.h"

#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
||
21 | |||
22 | namespace llvm { |
||
23 | namespace orc { |
||
24 | namespace shared { |
||
25 | |||
// Must be kept in-sync with compiler-rt/lib/orc/c-api.h.
union CWrapperFunctionResultDataUnion {
  char *ValuePtr;
  char Value[sizeof(ValuePtr)];
};

// Must be kept in-sync with compiler-rt/lib/orc/c-api.h.
typedef struct {
  CWrapperFunctionResultDataUnion Data;
  size_t Size;
} CWrapperFunctionResult;

/// C++ wrapper function result: Same as CWrapperFunctionResult but
/// auto-releases memory.
///
/// The (Size, Data) pair encodes one of four states:
///   - Size == 0 and Data.ValuePtr == nullptr: empty (default) value.
///   - Size == 0 and Data.ValuePtr != nullptr: out-of-band error; ValuePtr is
///     a malloc'd, null-terminated message.
///   - 0 < Size <= sizeof(Data.Value): the bytes are stored inline in
///     Data.Value.
///   - Size > sizeof(Data.Value): the bytes live in a malloc'd buffer pointed
///     to by Data.ValuePtr.
class WrapperFunctionResult {
public:
  /// Create a default (empty) WrapperFunctionResult.
  WrapperFunctionResult() { init(R); }

  /// Create a WrapperFunctionResult by taking ownership of a
  /// CWrapperFunctionResult.
  ///
  /// Warning: This should only be used by clients writing wrapper-function
  /// caller utilities (like TargetProcessControl).
  // The parameter is a by-value copy, so there is nothing to reset on the
  // caller's side; the member simply adopts the raw value.
  WrapperFunctionResult(CWrapperFunctionResult R) : R(R) {}

  WrapperFunctionResult(const WrapperFunctionResult &) = delete;
  WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;

  /// Move-construct: takes Other's contents and leaves Other empty.
  WrapperFunctionResult(WrapperFunctionResult &&Other) noexcept {
    init(R);
    std::swap(R, Other.R);
  }

  /// Move-assign: Other is left empty and the previous contents of *this are
  /// released when the temporary goes out of scope.
  WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) noexcept {
    WrapperFunctionResult Tmp(std::move(Other));
    std::swap(R, Tmp.R);
    return *this;
  }

  /// Frees the heap buffer, if any. Inline values and the empty value own no
  /// heap memory.
  ~WrapperFunctionResult() {
    const bool IsHeapValue = R.Size > sizeof(R.Data.Value);
    const bool IsOutOfBandError = R.Size == 0 && R.Data.ValuePtr != nullptr;
    if (IsHeapValue || IsOutOfBandError)
      free(R.Data.ValuePtr);
  }

  /// Release ownership of the contained CWrapperFunctionResult.
  /// Warning: Do not use -- this method will be removed in the future. It only
  /// exists to temporarily support some code that will eventually be moved to
  /// the ORC runtime.
  CWrapperFunctionResult release() {
    CWrapperFunctionResult Tmp;
    init(Tmp);
    std::swap(R, Tmp);
    return Tmp;
  }

  /// Get a pointer to the data contained in this instance.
  /// Must not be called on an out-of-band error value (asserts).
  char *data() {
    assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
           "Cannot get data for out-of-band error value");
    return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value;
  }

  /// Get a const pointer to the data contained in this instance.
  /// Must not be called on an out-of-band error value (asserts).
  const char *data() const {
    assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
           "Cannot get data for out-of-band error value");
    return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value;
  }

  /// Returns the size of the data contained in this instance.
  /// Must not be called on an out-of-band error value (asserts).
  size_t size() const {
    assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
           "Cannot get data for out-of-band error value");
    return R.Size;
  }

  /// Returns true if this value is equivalent to a default-constructed
  /// WrapperFunctionResult.
  bool empty() const { return R.Size == 0 && R.Data.ValuePtr == nullptr; }

  /// Create a WrapperFunctionResult with room for Size bytes. The storage is
  /// inline when Size fits in the data union and malloc'd otherwise; its
  /// contents are not initialized by this function.
  static WrapperFunctionResult allocate(size_t Size) {
    WrapperFunctionResult WFR;
    WFR.R.Size = Size;
    if (WFR.R.Size > sizeof(WFR.R.Data.Value))
      WFR.R.Data.ValuePtr = static_cast<char *>(malloc(WFR.R.Size));
    return WFR;
  }

  /// Copy from the given char range.
  static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
    auto WFR = allocate(Size);
    memcpy(WFR.data(), Source, Size);
    return WFR;
  }

  /// Copy from the given null-terminated string (includes the null-terminator).
  static WrapperFunctionResult copyFrom(const char *Source) {
    return copyFrom(Source, strlen(Source) + 1);
  }

  /// Copy from the given std::string (includes the null terminator).
  // Uses the string's own length rather than strlen so that a string holding
  // embedded NUL characters is copied in full instead of being truncated.
  static WrapperFunctionResult copyFrom(const std::string &Source) {
    return copyFrom(Source.data(), Source.size() + 1);
  }

  /// Create an out-of-band error by copying the given string.
  static WrapperFunctionResult createOutOfBandError(const char *Msg) {
    WrapperFunctionResult WFR;
    const size_t MsgBytes = strlen(Msg) + 1; // Include the terminator.
    char *Tmp = static_cast<char *>(malloc(MsgBytes));
    memcpy(Tmp, Msg, MsgBytes);
    // Size stays 0: a non-null pointer with zero size marks an error.
    WFR.R.Data.ValuePtr = Tmp;
    return WFR;
  }

  /// Create an out-of-band error by copying the given string.
  static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
    return createOutOfBandError(Msg.c_str());
  }

  /// If this value is an out-of-band error then this returns the error message,
  /// otherwise returns nullptr.
  const char *getOutOfBandError() const {
    return R.Size == 0 ? R.Data.ValuePtr : nullptr;
  }

private:
  /// Put a raw result into the empty state.
  static void init(CWrapperFunctionResult &R) {
    R.Data.ValuePtr = nullptr;
    R.Size = 0;
  }

  CWrapperFunctionResult R;
};
||
168 | |||
169 | namespace detail { |
||
170 | |||
171 | template <typename SPSArgListT, typename... ArgTs> |
||
172 | WrapperFunctionResult |
||
173 | serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) { |
||
174 | auto Result = WrapperFunctionResult::allocate(SPSArgListT::size(Args...)); |
||
175 | SPSOutputBuffer OB(Result.data(), Result.size()); |
||
176 | if (!SPSArgListT::serialize(OB, Args...)) |
||
177 | return WrapperFunctionResult::createOutOfBandError( |
||
178 | "Error serializing arguments to blob in call"); |
||
179 | return Result; |
||
180 | } |
||
181 | |||
/// Calls a handler with its arguments unpacked from a tuple and returns
/// whatever the handler returns. RetT is only used to select between this
/// primary template and the void specialization.
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  /// Invoke H with the tuple elements selected by the index sequence.
  template <typename HandlerT, typename ArgTupleT, std::size_t... Idx>
  static decltype(auto) call(HandlerT &&H, ArgTupleT &Args,
                             std::index_sequence<Idx...>) {
    return std::forward<HandlerT>(H)(std::get<Idx>(Args)...);
  }
};
||
190 | |||
191 | template <> class WrapperFunctionHandlerCaller<void> { |
||
192 | public: |
||
193 | template <typename HandlerT, typename ArgTupleT, std::size_t... I> |
||
194 | static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, |
||
195 | std::index_sequence<I...>) { |
||
196 | std::forward<HandlerT>(H)(std::get<I>(Args)...); |
||
197 | return SPSEmpty(); |
||
198 | } |
||
199 | }; |
||
200 | |||
/// Primary template: treats the handler type as a class with an operator()
/// and defers to the specialization selected by that operator's
/// member-function-pointer type.
template <typename HandlerImplT, template <typename> class ResultSerializer,
          typename... SPSTagTs>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference_t<HandlerImplT>::operator()),
          ResultSerializer, SPSTagTs...> {};
||
207 | |||
208 | template <typename RetT, typename... ArgTs, |
||
209 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
210 | class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, |
||
211 | SPSTagTs...> { |
||
212 | public: |
||
213 | using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; |
||
214 | using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; |
||
215 | |||
216 | template <typename HandlerT> |
||
217 | static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, |
||
218 | size_t ArgSize) { |
||
219 | ArgTuple Args; |
||
220 | if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) |
||
221 | return WrapperFunctionResult::createOutOfBandError( |
||
222 | "Could not deserialize arguments for wrapper function call"); |
||
223 | |||
224 | auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call( |
||
225 | std::forward<HandlerT>(H), Args, ArgIndices{}); |
||
226 | |||
227 | return ResultSerializer<decltype(HandlerResult)>::serialize( |
||
228 | std::move(HandlerResult)); |
||
229 | } |
||
230 | |||
231 | private: |
||
232 | template <std::size_t... I> |
||
233 | static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, |
||
234 | std::index_sequence<I...>) { |
||
235 | SPSInputBuffer IB(ArgData, ArgSize); |
||
236 | return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); |
||
237 | } |
||
238 | }; |
||
239 | |||
240 | // Map function pointers to function types. |
||
241 | template <typename RetT, typename... ArgTs, |
||
242 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
243 | class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer, |
||
244 | SPSTagTs...> |
||
245 | : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, |
||
246 | SPSTagTs...> {}; |
||
247 | |||
248 | // Map non-const member function types to function types. |
||
249 | template <typename ClassT, typename RetT, typename... ArgTs, |
||
250 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
251 | class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer, |
||
252 | SPSTagTs...> |
||
253 | : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, |
||
254 | SPSTagTs...> {}; |
||
255 | |||
256 | // Map const member function types to function types. |
||
257 | template <typename ClassT, typename RetT, typename... ArgTs, |
||
258 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
259 | class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const, |
||
260 | ResultSerializer, SPSTagTs...> |
||
261 | : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, |
||
262 | SPSTagTs...> {}; |
||
263 | |||
/// Primary template for async handlers: treats the handler type as a class
/// with an operator() and defers to the specialization selected by that
/// operator's member-function-pointer type.
template <typename HandlerImplT, template <typename> class ResultSerializer,
          typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper
    : public WrapperFunctionAsyncHandlerHelper<
          decltype(&std::remove_reference_t<HandlerImplT>::operator()),
          ResultSerializer, SPSTagTs...> {};
||
270 | |||
271 | template <typename RetT, typename SendResultT, typename... ArgTs, |
||
272 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
273 | class WrapperFunctionAsyncHandlerHelper<RetT(SendResultT, ArgTs...), |
||
274 | ResultSerializer, SPSTagTs...> { |
||
275 | public: |
||
276 | using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; |
||
277 | using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; |
||
278 | |||
279 | template <typename HandlerT, typename SendWrapperFunctionResultT> |
||
280 | static void applyAsync(HandlerT &&H, |
||
281 | SendWrapperFunctionResultT &&SendWrapperFunctionResult, |
||
282 | const char *ArgData, size_t ArgSize) { |
||
283 | ArgTuple Args; |
||
284 | if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) { |
||
285 | SendWrapperFunctionResult(WrapperFunctionResult::createOutOfBandError( |
||
286 | "Could not deserialize arguments for wrapper function call")); |
||
287 | return; |
||
288 | } |
||
289 | |||
290 | auto SendResult = |
||
291 | [SendWFR = std::move(SendWrapperFunctionResult)](auto Result) mutable { |
||
292 | using ResultT = decltype(Result); |
||
293 | SendWFR(ResultSerializer<ResultT>::serialize(std::move(Result))); |
||
294 | }; |
||
295 | |||
296 | callAsync(std::forward<HandlerT>(H), std::move(SendResult), std::move(Args), |
||
297 | ArgIndices{}); |
||
298 | } |
||
299 | |||
300 | private: |
||
301 | template <std::size_t... I> |
||
302 | static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, |
||
303 | std::index_sequence<I...>) { |
||
304 | SPSInputBuffer IB(ArgData, ArgSize); |
||
305 | return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); |
||
306 | } |
||
307 | |||
308 | template <typename HandlerT, typename SerializeAndSendResultT, |
||
309 | typename ArgTupleT, std::size_t... I> |
||
310 | static void callAsync(HandlerT &&H, |
||
311 | SerializeAndSendResultT &&SerializeAndSendResult, |
||
312 | ArgTupleT Args, std::index_sequence<I...>) { |
||
313 | (void)Args; // Silence a buggy GCC warning. |
||
314 | return std::forward<HandlerT>(H)(std::move(SerializeAndSendResult), |
||
315 | std::move(std::get<I>(Args))...); |
||
316 | } |
||
317 | }; |
||
318 | |||
319 | // Map function pointers to function types. |
||
320 | template <typename RetT, typename... ArgTs, |
||
321 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
322 | class WrapperFunctionAsyncHandlerHelper<RetT (*)(ArgTs...), ResultSerializer, |
||
323 | SPSTagTs...> |
||
324 | : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer, |
||
325 | SPSTagTs...> {}; |
||
326 | |||
327 | // Map non-const member function types to function types. |
||
328 | template <typename ClassT, typename RetT, typename... ArgTs, |
||
329 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
330 | class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...), |
||
331 | ResultSerializer, SPSTagTs...> |
||
332 | : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer, |
||
333 | SPSTagTs...> {}; |
||
334 | |||
335 | // Map const member function types to function types. |
||
336 | template <typename ClassT, typename RetT, typename... ArgTs, |
||
337 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
338 | class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...) const, |
||
339 | ResultSerializer, SPSTagTs...> |
||
340 | : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer, |
||
341 | SPSTagTs...> {}; |
||
342 | |||
343 | template <typename SPSRetTagT, typename RetT> class ResultSerializer { |
||
344 | public: |
||
345 | static WrapperFunctionResult serialize(RetT Result) { |
||
346 | return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( |
||
347 | Result); |
||
348 | } |
||
349 | }; |
||
350 | |||
351 | template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> { |
||
352 | public: |
||
353 | static WrapperFunctionResult serialize(Error Err) { |
||
354 | return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( |
||
355 | toSPSSerializable(std::move(Err))); |
||
356 | } |
||
357 | }; |
||
358 | |||
359 | template <typename SPSRetTagT> |
||
360 | class ResultSerializer<SPSRetTagT, ErrorSuccess> { |
||
361 | public: |
||
362 | static WrapperFunctionResult serialize(ErrorSuccess Err) { |
||
363 | return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( |
||
364 | toSPSSerializable(std::move(Err))); |
||
365 | } |
||
366 | }; |
||
367 | |||
368 | template <typename SPSRetTagT, typename T> |
||
369 | class ResultSerializer<SPSRetTagT, Expected<T>> { |
||
370 | public: |
||
371 | static WrapperFunctionResult serialize(Expected<T> E) { |
||
372 | return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( |
||
373 | toSPSSerializable(std::move(E))); |
||
374 | } |
||
375 | }; |
||
376 | |||
377 | template <typename SPSRetTagT, typename RetT> class ResultDeserializer { |
||
378 | public: |
||
379 | static RetT makeValue() { return RetT(); } |
||
380 | static void makeSafe(RetT &Result) {} |
||
381 | |||
382 | static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { |
||
383 | SPSInputBuffer IB(ArgData, ArgSize); |
||
384 | if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result)) |
||
385 | return make_error<StringError>( |
||
386 | "Error deserializing return value from blob in call", |
||
387 | inconvertibleErrorCode()); |
||
388 | return Error::success(); |
||
389 | } |
||
390 | }; |
||
391 | |||
392 | template <> class ResultDeserializer<SPSError, Error> { |
||
393 | public: |
||
394 | static Error makeValue() { return Error::success(); } |
||
395 | static void makeSafe(Error &Err) { cantFail(std::move(Err)); } |
||
396 | |||
397 | static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { |
||
398 | SPSInputBuffer IB(ArgData, ArgSize); |
||
399 | SPSSerializableError BSE; |
||
400 | if (!SPSArgList<SPSError>::deserialize(IB, BSE)) |
||
401 | return make_error<StringError>( |
||
402 | "Error deserializing return value from blob in call", |
||
403 | inconvertibleErrorCode()); |
||
404 | Err = fromSPSSerializable(std::move(BSE)); |
||
405 | return Error::success(); |
||
406 | } |
||
407 | }; |
||
408 | |||
409 | template <typename SPSTagT, typename T> |
||
410 | class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> { |
||
411 | public: |
||
412 | static Expected<T> makeValue() { return T(); } |
||
413 | static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); } |
||
414 | |||
415 | static Error deserialize(Expected<T> &E, const char *ArgData, |
||
416 | size_t ArgSize) { |
||
417 | SPSInputBuffer IB(ArgData, ArgSize); |
||
418 | SPSSerializableExpected<T> BSE; |
||
419 | if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE)) |
||
420 | return make_error<StringError>( |
||
421 | "Error deserializing return value from blob in call", |
||
422 | inconvertibleErrorCode()); |
||
423 | E = fromSPSSerializable(std::move(BSE)); |
||
424 | return Error::success(); |
||
425 | } |
||
426 | }; |
||
427 | |||
/// Intentionally empty primary template.
template <typename SPSRetTagT, typename RetT> class AsyncCallResultHelper {
  // Did you forget to use Error / Expected in your handler?
};
||
431 | |||
432 | } // end namespace detail |
||
433 | |||
434 | template <typename SPSSignature> class WrapperFunction; |
||
435 | |||
436 | template <typename SPSRetTagT, typename... SPSTagTs> |
||
437 | class WrapperFunction<SPSRetTagT(SPSTagTs...)> { |
||
438 | private: |
||
439 | template <typename RetT> |
||
440 | using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>; |
||
441 | |||
442 | public: |
||
443 | /// Call a wrapper function. Caller should be callable as |
||
444 | /// WrapperFunctionResult Fn(const char *ArgData, size_t ArgSize); |
||
445 | template <typename CallerFn, typename RetT, typename... ArgTs> |
||
446 | static Error call(const CallerFn &Caller, RetT &Result, |
||
447 | const ArgTs &...Args) { |
||
448 | |||
449 | // RetT might be an Error or Expected value. Set the checked flag now: |
||
450 | // we don't want the user to have to check the unused result if this |
||
451 | // operation fails. |
||
452 | detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result); |
||
453 | |||
454 | auto ArgBuffer = |
||
455 | detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>( |
||
456 | Args...); |
||
457 | if (const char *ErrMsg = ArgBuffer.getOutOfBandError()) |
||
458 | return make_error<StringError>(ErrMsg, inconvertibleErrorCode()); |
||
459 | |||
460 | WrapperFunctionResult ResultBuffer = |
||
461 | Caller(ArgBuffer.data(), ArgBuffer.size()); |
||
462 | if (auto ErrMsg = ResultBuffer.getOutOfBandError()) |
||
463 | return make_error<StringError>(ErrMsg, inconvertibleErrorCode()); |
||
464 | |||
465 | return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( |
||
466 | Result, ResultBuffer.data(), ResultBuffer.size()); |
||
467 | } |
||
468 | |||
469 | /// Call an async wrapper function. |
||
470 | /// Caller should be callable as |
||
471 | /// void Fn(unique_function<void(WrapperFunctionResult)> SendResult, |
||
472 | /// WrapperFunctionResult ArgBuffer); |
||
473 | template <typename AsyncCallerFn, typename SendDeserializedResultFn, |
||
474 | typename... ArgTs> |
||
475 | static void callAsync(AsyncCallerFn &&Caller, |
||
476 | SendDeserializedResultFn &&SendDeserializedResult, |
||
477 | const ArgTs &...Args) { |
||
478 | using RetT = typename std::tuple_element< |
||
479 | 1, typename detail::WrapperFunctionHandlerHelper< |
||
480 | std::remove_reference_t<SendDeserializedResultFn>, |
||
481 | ResultSerializer, SPSRetTagT>::ArgTuple>::type; |
||
482 | |||
483 | auto ArgBuffer = |
||
484 | detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>( |
||
485 | Args...); |
||
486 | if (auto *ErrMsg = ArgBuffer.getOutOfBandError()) { |
||
487 | SendDeserializedResult( |
||
488 | make_error<StringError>(ErrMsg, inconvertibleErrorCode()), |
||
489 | detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue()); |
||
490 | return; |
||
491 | } |
||
492 | |||
493 | auto SendSerializedResult = [SDR = std::move(SendDeserializedResult)]( |
||
494 | WrapperFunctionResult R) mutable { |
||
495 | RetT RetVal = detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue(); |
||
496 | detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(RetVal); |
||
497 | |||
498 | if (auto *ErrMsg = R.getOutOfBandError()) { |
||
499 | SDR(make_error<StringError>(ErrMsg, inconvertibleErrorCode()), |
||
500 | std::move(RetVal)); |
||
501 | return; |
||
502 | } |
||
503 | |||
504 | SPSInputBuffer IB(R.data(), R.size()); |
||
505 | if (auto Err = detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( |
||
506 | RetVal, R.data(), R.size())) |
||
507 | SDR(std::move(Err), std::move(RetVal)); |
||
508 | |||
509 | SDR(Error::success(), std::move(RetVal)); |
||
510 | }; |
||
511 | |||
512 | Caller(std::move(SendSerializedResult), ArgBuffer.data(), ArgBuffer.size()); |
||
513 | } |
||
514 | |||
515 | /// Handle a call to a wrapper function. |
||
516 | template <typename HandlerT> |
||
517 | static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, |
||
518 | HandlerT &&Handler) { |
||
519 | using WFHH = |
||
520 | detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>, |
||
521 | ResultSerializer, SPSTagTs...>; |
||
522 | return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize); |
||
523 | } |
||
524 | |||
525 | /// Handle a call to an async wrapper function. |
||
526 | template <typename HandlerT, typename SendResultT> |
||
527 | static void handleAsync(const char *ArgData, size_t ArgSize, |
||
528 | HandlerT &&Handler, SendResultT &&SendResult) { |
||
529 | using WFAHH = detail::WrapperFunctionAsyncHandlerHelper< |
||
530 | std::remove_reference_t<HandlerT>, ResultSerializer, SPSTagTs...>; |
||
531 | WFAHH::applyAsync(std::forward<HandlerT>(Handler), |
||
532 | std::forward<SendResultT>(SendResult), ArgData, ArgSize); |
||
533 | } |
||
534 | |||
535 | private: |
||
536 | template <typename T> static const T &makeSerializable(const T &Value) { |
||
537 | return Value; |
||
538 | } |
||
539 | |||
540 | static detail::SPSSerializableError makeSerializable(Error Err) { |
||
541 | return detail::toSPSSerializable(std::move(Err)); |
||
542 | } |
||
543 | |||
544 | template <typename T> |
||
545 | static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) { |
||
546 | return detail::toSPSSerializable(std::move(E)); |
||
547 | } |
||
548 | }; |
||
549 | |||
550 | template <typename... SPSTagTs> |
||
551 | class WrapperFunction<void(SPSTagTs...)> |
||
552 | : private WrapperFunction<SPSEmpty(SPSTagTs...)> { |
||
553 | |||
554 | public: |
||
555 | template <typename CallerFn, typename... ArgTs> |
||
556 | static Error call(const CallerFn &Caller, const ArgTs &...Args) { |
||
557 | SPSEmpty BE; |
||
558 | return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(Caller, BE, Args...); |
||
559 | } |
||
560 | |||
561 | template <typename AsyncCallerFn, typename SendDeserializedResultFn, |
||
562 | typename... ArgTs> |
||
563 | static void callAsync(AsyncCallerFn &&Caller, |
||
564 | SendDeserializedResultFn &&SendDeserializedResult, |
||
565 | const ArgTs &...Args) { |
||
566 | WrapperFunction<SPSEmpty(SPSTagTs...)>::callAsync( |
||
567 | std::forward<AsyncCallerFn>(Caller), |
||
568 | [SDR = std::move(SendDeserializedResult)](Error SerializeErr, |
||
569 | SPSEmpty E) mutable { |
||
570 | SDR(std::move(SerializeErr)); |
||
571 | }, |
||
572 | Args...); |
||
573 | } |
||
574 | |||
575 | using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle; |
||
576 | using WrapperFunction<SPSEmpty(SPSTagTs...)>::handleAsync; |
||
577 | }; |
||
578 | |||
579 | /// A function object that takes an ExecutorAddr as its first argument, |
||
580 | /// casts that address to a ClassT*, then calls the given method on that |
||
581 | /// pointer passing in the remaining function arguments. This utility |
||
582 | /// removes some of the boilerplate from writing wrappers for method calls. |
||
583 | /// |
||
584 | /// @code{.cpp} |
||
585 | /// class MyClass { |
||
586 | /// public: |
||
587 | /// void myMethod(uint32_t, bool) { ... } |
||
588 | /// }; |
||
589 | /// |
||
590 | /// // SPS Method signature -- note MyClass object address as first argument. |
||
591 | /// using SPSMyMethodWrapperSignature = |
||
592 | /// SPSTuple<SPSExecutorAddr, uint32_t, bool>; |
||
593 | /// |
||
594 | /// WrapperFunctionResult |
||
595 | /// myMethodCallWrapper(const char *ArgData, size_t ArgSize) { |
||
596 | /// return WrapperFunction<SPSMyMethodWrapperSignature>::handle( |
||
597 | /// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod)); |
||
598 | /// } |
||
599 | /// @endcode |
||
600 | /// |
||
601 | template <typename RetT, typename ClassT, typename... ArgTs> |
||
602 | class MethodWrapperHandler { |
||
603 | public: |
||
604 | using MethodT = RetT (ClassT::*)(ArgTs...); |
||
605 | MethodWrapperHandler(MethodT M) : M(M) {} |
||
606 | RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) { |
||
607 | return (ObjAddr.toPtr<ClassT*>()->*M)(std::forward<ArgTs>(Args)...); |
||
608 | } |
||
609 | |||
610 | private: |
||
611 | MethodT M; |
||
612 | }; |
||
613 | |||
614 | /// Create a MethodWrapperHandler object from the given method pointer. |
||
615 | template <typename RetT, typename ClassT, typename... ArgTs> |
||
616 | MethodWrapperHandler<RetT, ClassT, ArgTs...> |
||
617 | makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) { |
||
618 | return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method); |
||
619 | } |
||
620 | |||
621 | /// Represents a serialized wrapper function call. |
||
622 | /// Serializing calls themselves allows us to batch them: We can make one |
||
623 | /// "run-wrapper-functions" utility and send it a list of calls to run. |
||
624 | /// |
||
625 | /// The motivating use-case for this API is JITLink allocation actions, where |
||
626 | /// we want to run multiple functions to finalize linked memory without having |
||
627 | /// to make separate IPC calls for each one. |
||
628 | class WrapperFunctionCall { |
||
629 | public: |
||
630 | using ArgDataBufferType = SmallVector<char, 24>; |
||
631 | |||
632 | /// Create a WrapperFunctionCall using the given SPS serializer to serialize |
||
633 | /// the arguments. |
||
634 | template <typename SPSSerializer, typename... ArgTs> |
||
635 | static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr, |
||
636 | const ArgTs &...Args) { |
||
637 | ArgDataBufferType ArgData; |
||
638 | ArgData.resize(SPSSerializer::size(Args...)); |
||
639 | SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(), |
||
640 | ArgData.size()); |
||
641 | if (SPSSerializer::serialize(OB, Args...)) |
||
642 | return WrapperFunctionCall(FnAddr, std::move(ArgData)); |
||
643 | return make_error<StringError>("Cannot serialize arguments for " |
||
644 | "AllocActionCall", |
||
645 | inconvertibleErrorCode()); |
||
646 | } |
||
647 | |||
648 | WrapperFunctionCall() = default; |
||
649 | |||
650 | /// Create a WrapperFunctionCall from a target function and arg buffer. |
||
651 | WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData) |
||
652 | : FnAddr(FnAddr), ArgData(std::move(ArgData)) {} |
||
653 | |||
654 | /// Returns the address to be called. |
||
655 | const ExecutorAddr &getCallee() const { return FnAddr; } |
||
656 | |||
657 | /// Returns the argument data. |
||
658 | const ArgDataBufferType &getArgData() const { return ArgData; } |
||
659 | |||
660 | /// WrapperFunctionCalls convert to true if the callee is non-null. |
||
661 | explicit operator bool() const { return !!FnAddr; } |
||
662 | |||
663 | /// Run call returning raw WrapperFunctionResult. |
||
664 | shared::WrapperFunctionResult run() const { |
||
665 | using FnTy = |
||
666 | shared::CWrapperFunctionResult(const char *ArgData, size_t ArgSize); |
||
667 | return shared::WrapperFunctionResult( |
||
668 | FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size())); |
||
669 | } |
||
670 | |||
671 | /// Run call and deserialize result using SPS. |
||
672 | template <typename SPSRetT, typename RetT> |
||
673 | std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error> |
||
674 | runWithSPSRet(RetT &RetVal) const { |
||
675 | auto WFR = run(); |
||
676 | if (const char *ErrMsg = WFR.getOutOfBandError()) |
||
677 | return make_error<StringError>(ErrMsg, inconvertibleErrorCode()); |
||
678 | shared::SPSInputBuffer IB(WFR.data(), WFR.size()); |
||
679 | if (!shared::SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal)) |
||
680 | return make_error<StringError>("Could not deserialize result from " |
||
681 | "serialized wrapper function call", |
||
682 | inconvertibleErrorCode()); |
||
683 | return Error::success(); |
||
684 | } |
||
685 | |||
686 | /// Overload for SPS functions returning void. |
||
687 | template <typename SPSRetT> |
||
688 | std::enable_if_t<std::is_same<SPSRetT, void>::value, Error> |
||
689 | runWithSPSRet() const { |
||
690 | shared::SPSEmpty E; |
||
691 | return runWithSPSRet<shared::SPSEmpty>(E); |
||
692 | } |
||
693 | |||
694 | /// Run call and deserialize an SPSError result. SPSError returns and |
||
695 | /// deserialization failures are merged into the returned error. |
||
696 | Error runWithSPSRetErrorMerged() const { |
||
697 | detail::SPSSerializableError RetErr; |
||
698 | if (auto Err = runWithSPSRet<SPSError>(RetErr)) |
||
699 | return Err; |
||
700 | return detail::fromSPSSerializable(std::move(RetErr)); |
||
701 | } |
||
702 | |||
703 | private: |
||
704 | orc::ExecutorAddr FnAddr; |
||
705 | ArgDataBufferType ArgData; |
||
706 | }; |
||
707 | |||
708 | using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>; |
||
709 | |||
710 | template <> |
||
711 | class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> { |
||
712 | public: |
||
713 | static size_t size(const WrapperFunctionCall &WFC) { |
||
714 | return SPSWrapperFunctionCall::AsArgList::size(WFC.getCallee(), |
||
715 | WFC.getArgData()); |
||
716 | } |
||
717 | |||
718 | static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) { |
||
719 | return SPSWrapperFunctionCall::AsArgList::serialize(OB, WFC.getCallee(), |
||
720 | WFC.getArgData()); |
||
721 | } |
||
722 | |||
723 | static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) { |
||
724 | ExecutorAddr FnAddr; |
||
725 | WrapperFunctionCall::ArgDataBufferType ArgData; |
||
726 | if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData)) |
||
727 | return false; |
||
728 | WFC = WrapperFunctionCall(FnAddr, std::move(ArgData)); |
||
729 | return true; |
||
730 | } |
||
731 | }; |
||
732 | |||
733 | } // end namespace shared |
||
734 | } // end namespace orc |
||
735 | } // end namespace llvm |
||
736 | |||
737 | #endif // LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H |