Details | Last modification | View Log | RSS feed
| Rev | Author | Line No. | Line |
|---|---|---|---|
| 14 | pmbaty | 1 | //===- WrapperFunctionUtils.h - Utilities for wrapper functions -*- C++ -*-===// |
| 2 | // |
||
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
||
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
||
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
||
| 6 | // |
||
| 7 | //===----------------------------------------------------------------------===// |
||
| 8 | // |
||
| 9 | // A buffer for serialized results. |
||
| 10 | // |
||
| 11 | //===----------------------------------------------------------------------===// |
||
| 12 | |||
| 13 | #ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H |
||
| 14 | #define LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H |
||
| 15 | |||
| 16 | #include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" |
||
| 17 | #include "llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h" |
||
| 18 | #include "llvm/Support/Error.h" |
||
| 19 | |||
| 20 | #include <type_traits> |
||
| 21 | |||
| 22 | namespace llvm { |
||
| 23 | namespace orc { |
||
| 24 | namespace shared { |
||
| 25 | |||
// Must be kept in-sync with compiler-rt/lib/orc/c-api.h.
/// Storage for wrapper function result data: either a pointer to an
/// out-of-line buffer, or the bytes themselves held inline when they fit in
/// the space of a pointer.
union CWrapperFunctionResultDataUnion {
  char *ValuePtr;
  char Value[sizeof(ValuePtr)];
};

// Must be kept in-sync with compiler-rt/lib/orc/c-api.h.
/// C-compatible wrapper function result. WrapperFunctionResult interprets it
/// as follows:
///   - Size == 0, Data.ValuePtr == nullptr: empty (default) value.
///   - Size == 0, Data.ValuePtr != nullptr: out-of-band error; ValuePtr points
///     to a null-terminated message that is released with free().
///   - 0 < Size <= sizeof(Data.Value): bytes stored inline in Data.Value.
///   - Size > sizeof(Data.Value): bytes stored in a buffer at Data.ValuePtr
///     that is released with free().
typedef struct {
  CWrapperFunctionResultDataUnion Data;
  size_t Size;
} CWrapperFunctionResult;

/// C++ wrapper function result: Same as CWrapperFunctionResult but
/// auto-releases memory.
///
/// Instances are move-only. On destruction, any out-of-line data buffer or
/// out-of-band error message is released with free().
class WrapperFunctionResult {
public:
  /// Create a default (empty) WrapperFunctionResult.
  WrapperFunctionResult() { init(R); }

  /// Create a WrapperFunctionResult by taking ownership of a
  /// CWrapperFunctionResult.
  ///
  /// Warning: This should only be used by clients writing wrapper-function
  /// caller utilities (like TargetProcessControl).
  ///
  /// CR is taken by value, so the caller's copy still refers to the same
  /// out-of-line buffer (if any); the caller must not free it separately.
  WrapperFunctionResult(CWrapperFunctionResult CR) : R(CR) {}

  WrapperFunctionResult(const WrapperFunctionResult &) = delete;
  WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;

  /// Take Other's contents, leaving Other empty.
  WrapperFunctionResult(WrapperFunctionResult &&Other) noexcept {
    init(R);
    std::swap(R, Other.R);
  }

  /// Release this instance's contents and take Other's, leaving Other empty.
  /// Safe for self-assignment: Other is first moved into a temporary.
  WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) noexcept {
    WrapperFunctionResult Tmp(std::move(Other));
    std::swap(R, Tmp.R);
    return *this;
  }

  ~WrapperFunctionResult() {
    // Only out-of-line data and out-of-band error messages own heap memory;
    // inline data lives in R.Data.Value itself.
    if ((R.Size > sizeof(R.Data.Value)) ||
        (R.Size == 0 && R.Data.ValuePtr != nullptr))
      free(R.Data.ValuePtr);
  }

  /// Release ownership of the contained CWrapperFunctionResult, leaving this
  /// instance empty.
  /// Warning: Do not use -- this method will be removed in the future. It only
  /// exists to temporarily support some code that will eventually be moved to
  /// the ORC runtime.
  CWrapperFunctionResult release() {
    CWrapperFunctionResult Tmp;
    init(Tmp);
    std::swap(R, Tmp);
    return Tmp;
  }

  /// Get a pointer to the data contained in this instance.
  /// Must not be called on an out-of-band error value (asserts).
  char *data() {
    assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
           "Cannot get data for out-of-band error value");
    return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value;
  }

  /// Get a const pointer to the data contained in this instance.
  /// Must not be called on an out-of-band error value (asserts).
  const char *data() const {
    assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
           "Cannot get data for out-of-band error value");
    return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value;
  }

  /// Returns the size of the data contained in this instance.
  /// Must not be called on an out-of-band error value (asserts).
  size_t size() const {
    assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
           "Cannot get data for out-of-band error value");
    return R.Size;
  }

  /// Returns true if this value is equivalent to a default-constructed
  /// WrapperFunctionResult.
  bool empty() const { return R.Size == 0 && R.Data.ValuePtr == nullptr; }

  /// Create a WrapperFunctionResult holding Size bytes of uninitialized
  /// storage (inline if it fits in Data.Value, otherwise malloc'd). Use
  /// data() to fill it in.
  static WrapperFunctionResult allocate(size_t Size) {
    // Start from an empty value.
    WrapperFunctionResult WFR;
    WFR.R.Size = Size;
    if (WFR.R.Size > sizeof(WFR.R.Data.Value))
      WFR.R.Data.ValuePtr = (char *)malloc(WFR.R.Size);
    return WFR;
  }

  /// Copy from the given char range. Source may be null when Size is 0.
  static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
    auto WFR = allocate(Size);
    // memcpy requires valid pointers even for a zero-length copy, and Source
    // may legitimately be null for an empty range.
    if (Size != 0)
      memcpy(WFR.data(), Source, Size);
    return WFR;
  }

  /// Copy from the given null-terminated string (includes the null-terminator).
  static WrapperFunctionResult copyFrom(const char *Source) {
    return copyFrom(Source, strlen(Source) + 1);
  }

  /// Copy from the given std::string (includes the null terminator).
  static WrapperFunctionResult copyFrom(const std::string &Source) {
    return copyFrom(Source.c_str());
  }

  /// Create an out-of-band error by copying the given string.
  static WrapperFunctionResult createOutOfBandError(const char *Msg) {
    // Size stays 0: a non-null ValuePtr with Size 0 marks an out-of-band
    // error.
    WrapperFunctionResult WFR;
    char *Tmp = (char *)malloc(strlen(Msg) + 1);
    strcpy(Tmp, Msg);
    WFR.R.Data.ValuePtr = Tmp;
    return WFR;
  }

  /// Create an out-of-band error by copying the given string.
  static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
    return createOutOfBandError(Msg.c_str());
  }

  /// If this value is an out-of-band error then this returns the error message,
  /// otherwise returns nullptr.
  const char *getOutOfBandError() const {
    return R.Size == 0 ? R.Data.ValuePtr : nullptr;
  }

private:
  /// Reset CR to the empty state (null pointer, zero size).
  static void init(CWrapperFunctionResult &CR) {
    CR.Data.ValuePtr = nullptr;
    CR.Size = 0;
  }

  CWrapperFunctionResult R;
};
||
| 168 | |||
| 169 | namespace detail { |
||
| 170 | |||
| 171 | template <typename SPSArgListT, typename... ArgTs> |
||
| 172 | WrapperFunctionResult |
||
| 173 | serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) { |
||
| 174 | auto Result = WrapperFunctionResult::allocate(SPSArgListT::size(Args...)); |
||
| 175 | SPSOutputBuffer OB(Result.data(), Result.size()); |
||
| 176 | if (!SPSArgListT::serialize(OB, Args...)) |
||
| 177 | return WrapperFunctionResult::createOutOfBandError( |
||
| 178 | "Error serializing arguments to blob in call"); |
||
| 179 | return Result; |
||
| 180 | } |
||
| 181 | |||
| 182 | template <typename RetT> class WrapperFunctionHandlerCaller { |
||
| 183 | public: |
||
| 184 | template <typename HandlerT, typename ArgTupleT, std::size_t... I> |
||
| 185 | static decltype(auto) call(HandlerT &&H, ArgTupleT &Args, |
||
| 186 | std::index_sequence<I...>) { |
||
| 187 | return std::forward<HandlerT>(H)(std::get<I>(Args)...); |
||
| 188 | } |
||
| 189 | }; |
||
| 190 | |||
| 191 | template <> class WrapperFunctionHandlerCaller<void> { |
||
| 192 | public: |
||
| 193 | template <typename HandlerT, typename ArgTupleT, std::size_t... I> |
||
| 194 | static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, |
||
| 195 | std::index_sequence<I...>) { |
||
| 196 | std::forward<HandlerT>(H)(std::get<I>(Args)...); |
||
| 197 | return SPSEmpty(); |
||
| 198 | } |
||
| 199 | }; |
||
| 200 | |||
| 201 | template <typename WrapperFunctionImplT, |
||
| 202 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
| 203 | class WrapperFunctionHandlerHelper |
||
| 204 | : public WrapperFunctionHandlerHelper< |
||
| 205 | decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()), |
||
| 206 | ResultSerializer, SPSTagTs...> {}; |
||
| 207 | |||
| 208 | template <typename RetT, typename... ArgTs, |
||
| 209 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
| 210 | class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, |
||
| 211 | SPSTagTs...> { |
||
| 212 | public: |
||
| 213 | using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; |
||
| 214 | using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; |
||
| 215 | |||
| 216 | template <typename HandlerT> |
||
| 217 | static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, |
||
| 218 | size_t ArgSize) { |
||
| 219 | ArgTuple Args; |
||
| 220 | if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) |
||
| 221 | return WrapperFunctionResult::createOutOfBandError( |
||
| 222 | "Could not deserialize arguments for wrapper function call"); |
||
| 223 | |||
| 224 | auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call( |
||
| 225 | std::forward<HandlerT>(H), Args, ArgIndices{}); |
||
| 226 | |||
| 227 | return ResultSerializer<decltype(HandlerResult)>::serialize( |
||
| 228 | std::move(HandlerResult)); |
||
| 229 | } |
||
| 230 | |||
| 231 | private: |
||
| 232 | template <std::size_t... I> |
||
| 233 | static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, |
||
| 234 | std::index_sequence<I...>) { |
||
| 235 | SPSInputBuffer IB(ArgData, ArgSize); |
||
| 236 | return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); |
||
| 237 | } |
||
| 238 | }; |
||
| 239 | |||
| 240 | // Map function pointers to function types. |
||
| 241 | template <typename RetT, typename... ArgTs, |
||
| 242 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
| 243 | class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer, |
||
| 244 | SPSTagTs...> |
||
| 245 | : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, |
||
| 246 | SPSTagTs...> {}; |
||
| 247 | |||
| 248 | // Map non-const member function types to function types. |
||
| 249 | template <typename ClassT, typename RetT, typename... ArgTs, |
||
| 250 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
| 251 | class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer, |
||
| 252 | SPSTagTs...> |
||
| 253 | : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, |
||
| 254 | SPSTagTs...> {}; |
||
| 255 | |||
| 256 | // Map const member function types to function types. |
||
| 257 | template <typename ClassT, typename RetT, typename... ArgTs, |
||
| 258 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
| 259 | class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const, |
||
| 260 | ResultSerializer, SPSTagTs...> |
||
| 261 | : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, |
||
| 262 | SPSTagTs...> {}; |
||
| 263 | |||
| 264 | template <typename WrapperFunctionImplT, |
||
| 265 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
| 266 | class WrapperFunctionAsyncHandlerHelper |
||
| 267 | : public WrapperFunctionAsyncHandlerHelper< |
||
| 268 | decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()), |
||
| 269 | ResultSerializer, SPSTagTs...> {}; |
||
| 270 | |||
| 271 | template <typename RetT, typename SendResultT, typename... ArgTs, |
||
| 272 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
| 273 | class WrapperFunctionAsyncHandlerHelper<RetT(SendResultT, ArgTs...), |
||
| 274 | ResultSerializer, SPSTagTs...> { |
||
| 275 | public: |
||
| 276 | using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; |
||
| 277 | using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; |
||
| 278 | |||
| 279 | template <typename HandlerT, typename SendWrapperFunctionResultT> |
||
| 280 | static void applyAsync(HandlerT &&H, |
||
| 281 | SendWrapperFunctionResultT &&SendWrapperFunctionResult, |
||
| 282 | const char *ArgData, size_t ArgSize) { |
||
| 283 | ArgTuple Args; |
||
| 284 | if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) { |
||
| 285 | SendWrapperFunctionResult(WrapperFunctionResult::createOutOfBandError( |
||
| 286 | "Could not deserialize arguments for wrapper function call")); |
||
| 287 | return; |
||
| 288 | } |
||
| 289 | |||
| 290 | auto SendResult = |
||
| 291 | [SendWFR = std::move(SendWrapperFunctionResult)](auto Result) mutable { |
||
| 292 | using ResultT = decltype(Result); |
||
| 293 | SendWFR(ResultSerializer<ResultT>::serialize(std::move(Result))); |
||
| 294 | }; |
||
| 295 | |||
| 296 | callAsync(std::forward<HandlerT>(H), std::move(SendResult), std::move(Args), |
||
| 297 | ArgIndices{}); |
||
| 298 | } |
||
| 299 | |||
| 300 | private: |
||
| 301 | template <std::size_t... I> |
||
| 302 | static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, |
||
| 303 | std::index_sequence<I...>) { |
||
| 304 | SPSInputBuffer IB(ArgData, ArgSize); |
||
| 305 | return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); |
||
| 306 | } |
||
| 307 | |||
| 308 | template <typename HandlerT, typename SerializeAndSendResultT, |
||
| 309 | typename ArgTupleT, std::size_t... I> |
||
| 310 | static void callAsync(HandlerT &&H, |
||
| 311 | SerializeAndSendResultT &&SerializeAndSendResult, |
||
| 312 | ArgTupleT Args, std::index_sequence<I...>) { |
||
| 313 | (void)Args; // Silence a buggy GCC warning. |
||
| 314 | return std::forward<HandlerT>(H)(std::move(SerializeAndSendResult), |
||
| 315 | std::move(std::get<I>(Args))...); |
||
| 316 | } |
||
| 317 | }; |
||
| 318 | |||
| 319 | // Map function pointers to function types. |
||
| 320 | template <typename RetT, typename... ArgTs, |
||
| 321 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
| 322 | class WrapperFunctionAsyncHandlerHelper<RetT (*)(ArgTs...), ResultSerializer, |
||
| 323 | SPSTagTs...> |
||
| 324 | : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer, |
||
| 325 | SPSTagTs...> {}; |
||
| 326 | |||
| 327 | // Map non-const member function types to function types. |
||
| 328 | template <typename ClassT, typename RetT, typename... ArgTs, |
||
| 329 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
| 330 | class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...), |
||
| 331 | ResultSerializer, SPSTagTs...> |
||
| 332 | : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer, |
||
| 333 | SPSTagTs...> {}; |
||
| 334 | |||
| 335 | // Map const member function types to function types. |
||
| 336 | template <typename ClassT, typename RetT, typename... ArgTs, |
||
| 337 | template <typename> class ResultSerializer, typename... SPSTagTs> |
||
| 338 | class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...) const, |
||
| 339 | ResultSerializer, SPSTagTs...> |
||
| 340 | : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer, |
||
| 341 | SPSTagTs...> {}; |
||
| 342 | |||
| 343 | template <typename SPSRetTagT, typename RetT> class ResultSerializer { |
||
| 344 | public: |
||
| 345 | static WrapperFunctionResult serialize(RetT Result) { |
||
| 346 | return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( |
||
| 347 | Result); |
||
| 348 | } |
||
| 349 | }; |
||
| 350 | |||
| 351 | template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> { |
||
| 352 | public: |
||
| 353 | static WrapperFunctionResult serialize(Error Err) { |
||
| 354 | return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( |
||
| 355 | toSPSSerializable(std::move(Err))); |
||
| 356 | } |
||
| 357 | }; |
||
| 358 | |||
| 359 | template <typename SPSRetTagT> |
||
| 360 | class ResultSerializer<SPSRetTagT, ErrorSuccess> { |
||
| 361 | public: |
||
| 362 | static WrapperFunctionResult serialize(ErrorSuccess Err) { |
||
| 363 | return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( |
||
| 364 | toSPSSerializable(std::move(Err))); |
||
| 365 | } |
||
| 366 | }; |
||
| 367 | |||
| 368 | template <typename SPSRetTagT, typename T> |
||
| 369 | class ResultSerializer<SPSRetTagT, Expected<T>> { |
||
| 370 | public: |
||
| 371 | static WrapperFunctionResult serialize(Expected<T> E) { |
||
| 372 | return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( |
||
| 373 | toSPSSerializable(std::move(E))); |
||
| 374 | } |
||
| 375 | }; |
||
| 376 | |||
| 377 | template <typename SPSRetTagT, typename RetT> class ResultDeserializer { |
||
| 378 | public: |
||
| 379 | static RetT makeValue() { return RetT(); } |
||
| 380 | static void makeSafe(RetT &Result) {} |
||
| 381 | |||
| 382 | static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { |
||
| 383 | SPSInputBuffer IB(ArgData, ArgSize); |
||
| 384 | if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result)) |
||
| 385 | return make_error<StringError>( |
||
| 386 | "Error deserializing return value from blob in call", |
||
| 387 | inconvertibleErrorCode()); |
||
| 388 | return Error::success(); |
||
| 389 | } |
||
| 390 | }; |
||
| 391 | |||
| 392 | template <> class ResultDeserializer<SPSError, Error> { |
||
| 393 | public: |
||
| 394 | static Error makeValue() { return Error::success(); } |
||
| 395 | static void makeSafe(Error &Err) { cantFail(std::move(Err)); } |
||
| 396 | |||
| 397 | static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { |
||
| 398 | SPSInputBuffer IB(ArgData, ArgSize); |
||
| 399 | SPSSerializableError BSE; |
||
| 400 | if (!SPSArgList<SPSError>::deserialize(IB, BSE)) |
||
| 401 | return make_error<StringError>( |
||
| 402 | "Error deserializing return value from blob in call", |
||
| 403 | inconvertibleErrorCode()); |
||
| 404 | Err = fromSPSSerializable(std::move(BSE)); |
||
| 405 | return Error::success(); |
||
| 406 | } |
||
| 407 | }; |
||
| 408 | |||
| 409 | template <typename SPSTagT, typename T> |
||
| 410 | class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> { |
||
| 411 | public: |
||
| 412 | static Expected<T> makeValue() { return T(); } |
||
| 413 | static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); } |
||
| 414 | |||
| 415 | static Error deserialize(Expected<T> &E, const char *ArgData, |
||
| 416 | size_t ArgSize) { |
||
| 417 | SPSInputBuffer IB(ArgData, ArgSize); |
||
| 418 | SPSSerializableExpected<T> BSE; |
||
| 419 | if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE)) |
||
| 420 | return make_error<StringError>( |
||
| 421 | "Error deserializing return value from blob in call", |
||
| 422 | inconvertibleErrorCode()); |
||
| 423 | E = fromSPSSerializable(std::move(BSE)); |
||
| 424 | return Error::success(); |
||
| 425 | } |
||
| 426 | }; |
||
| 427 | |||
/// Empty class template with no members; nothing visible in this header
/// instantiates it.
template <typename SPSRetTagT, typename RetT> class AsyncCallResultHelper {
  // Did you forget to use Error / Expected in your handler?
};
||
| 431 | |||
| 432 | } // end namespace detail |
||
| 433 | |||
| 434 | template <typename SPSSignature> class WrapperFunction; |
||
| 435 | |||
| 436 | template <typename SPSRetTagT, typename... SPSTagTs> |
||
| 437 | class WrapperFunction<SPSRetTagT(SPSTagTs...)> { |
||
| 438 | private: |
||
| 439 | template <typename RetT> |
||
| 440 | using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>; |
||
| 441 | |||
| 442 | public: |
||
| 443 | /// Call a wrapper function. Caller should be callable as |
||
| 444 | /// WrapperFunctionResult Fn(const char *ArgData, size_t ArgSize); |
||
| 445 | template <typename CallerFn, typename RetT, typename... ArgTs> |
||
| 446 | static Error call(const CallerFn &Caller, RetT &Result, |
||
| 447 | const ArgTs &...Args) { |
||
| 448 | |||
| 449 | // RetT might be an Error or Expected value. Set the checked flag now: |
||
| 450 | // we don't want the user to have to check the unused result if this |
||
| 451 | // operation fails. |
||
| 452 | detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result); |
||
| 453 | |||
| 454 | auto ArgBuffer = |
||
| 455 | detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>( |
||
| 456 | Args...); |
||
| 457 | if (const char *ErrMsg = ArgBuffer.getOutOfBandError()) |
||
| 458 | return make_error<StringError>(ErrMsg, inconvertibleErrorCode()); |
||
| 459 | |||
| 460 | WrapperFunctionResult ResultBuffer = |
||
| 461 | Caller(ArgBuffer.data(), ArgBuffer.size()); |
||
| 462 | if (auto ErrMsg = ResultBuffer.getOutOfBandError()) |
||
| 463 | return make_error<StringError>(ErrMsg, inconvertibleErrorCode()); |
||
| 464 | |||
| 465 | return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( |
||
| 466 | Result, ResultBuffer.data(), ResultBuffer.size()); |
||
| 467 | } |
||
| 468 | |||
| 469 | /// Call an async wrapper function. |
||
| 470 | /// Caller should be callable as |
||
| 471 | /// void Fn(unique_function<void(WrapperFunctionResult)> SendResult, |
||
| 472 | /// WrapperFunctionResult ArgBuffer); |
||
| 473 | template <typename AsyncCallerFn, typename SendDeserializedResultFn, |
||
| 474 | typename... ArgTs> |
||
| 475 | static void callAsync(AsyncCallerFn &&Caller, |
||
| 476 | SendDeserializedResultFn &&SendDeserializedResult, |
||
| 477 | const ArgTs &...Args) { |
||
| 478 | using RetT = typename std::tuple_element< |
||
| 479 | 1, typename detail::WrapperFunctionHandlerHelper< |
||
| 480 | std::remove_reference_t<SendDeserializedResultFn>, |
||
| 481 | ResultSerializer, SPSRetTagT>::ArgTuple>::type; |
||
| 482 | |||
| 483 | auto ArgBuffer = |
||
| 484 | detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>( |
||
| 485 | Args...); |
||
| 486 | if (auto *ErrMsg = ArgBuffer.getOutOfBandError()) { |
||
| 487 | SendDeserializedResult( |
||
| 488 | make_error<StringError>(ErrMsg, inconvertibleErrorCode()), |
||
| 489 | detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue()); |
||
| 490 | return; |
||
| 491 | } |
||
| 492 | |||
| 493 | auto SendSerializedResult = [SDR = std::move(SendDeserializedResult)]( |
||
| 494 | WrapperFunctionResult R) mutable { |
||
| 495 | RetT RetVal = detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue(); |
||
| 496 | detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(RetVal); |
||
| 497 | |||
| 498 | if (auto *ErrMsg = R.getOutOfBandError()) { |
||
| 499 | SDR(make_error<StringError>(ErrMsg, inconvertibleErrorCode()), |
||
| 500 | std::move(RetVal)); |
||
| 501 | return; |
||
| 502 | } |
||
| 503 | |||
| 504 | SPSInputBuffer IB(R.data(), R.size()); |
||
| 505 | if (auto Err = detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( |
||
| 506 | RetVal, R.data(), R.size())) |
||
| 507 | SDR(std::move(Err), std::move(RetVal)); |
||
| 508 | |||
| 509 | SDR(Error::success(), std::move(RetVal)); |
||
| 510 | }; |
||
| 511 | |||
| 512 | Caller(std::move(SendSerializedResult), ArgBuffer.data(), ArgBuffer.size()); |
||
| 513 | } |
||
| 514 | |||
| 515 | /// Handle a call to a wrapper function. |
||
| 516 | template <typename HandlerT> |
||
| 517 | static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, |
||
| 518 | HandlerT &&Handler) { |
||
| 519 | using WFHH = |
||
| 520 | detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>, |
||
| 521 | ResultSerializer, SPSTagTs...>; |
||
| 522 | return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize); |
||
| 523 | } |
||
| 524 | |||
| 525 | /// Handle a call to an async wrapper function. |
||
| 526 | template <typename HandlerT, typename SendResultT> |
||
| 527 | static void handleAsync(const char *ArgData, size_t ArgSize, |
||
| 528 | HandlerT &&Handler, SendResultT &&SendResult) { |
||
| 529 | using WFAHH = detail::WrapperFunctionAsyncHandlerHelper< |
||
| 530 | std::remove_reference_t<HandlerT>, ResultSerializer, SPSTagTs...>; |
||
| 531 | WFAHH::applyAsync(std::forward<HandlerT>(Handler), |
||
| 532 | std::forward<SendResultT>(SendResult), ArgData, ArgSize); |
||
| 533 | } |
||
| 534 | |||
| 535 | private: |
||
| 536 | template <typename T> static const T &makeSerializable(const T &Value) { |
||
| 537 | return Value; |
||
| 538 | } |
||
| 539 | |||
| 540 | static detail::SPSSerializableError makeSerializable(Error Err) { |
||
| 541 | return detail::toSPSSerializable(std::move(Err)); |
||
| 542 | } |
||
| 543 | |||
| 544 | template <typename T> |
||
| 545 | static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) { |
||
| 546 | return detail::toSPSSerializable(std::move(E)); |
||
| 547 | } |
||
| 548 | }; |
||
| 549 | |||
| 550 | template <typename... SPSTagTs> |
||
| 551 | class WrapperFunction<void(SPSTagTs...)> |
||
| 552 | : private WrapperFunction<SPSEmpty(SPSTagTs...)> { |
||
| 553 | |||
| 554 | public: |
||
| 555 | template <typename CallerFn, typename... ArgTs> |
||
| 556 | static Error call(const CallerFn &Caller, const ArgTs &...Args) { |
||
| 557 | SPSEmpty BE; |
||
| 558 | return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(Caller, BE, Args...); |
||
| 559 | } |
||
| 560 | |||
| 561 | template <typename AsyncCallerFn, typename SendDeserializedResultFn, |
||
| 562 | typename... ArgTs> |
||
| 563 | static void callAsync(AsyncCallerFn &&Caller, |
||
| 564 | SendDeserializedResultFn &&SendDeserializedResult, |
||
| 565 | const ArgTs &...Args) { |
||
| 566 | WrapperFunction<SPSEmpty(SPSTagTs...)>::callAsync( |
||
| 567 | std::forward<AsyncCallerFn>(Caller), |
||
| 568 | [SDR = std::move(SendDeserializedResult)](Error SerializeErr, |
||
| 569 | SPSEmpty E) mutable { |
||
| 570 | SDR(std::move(SerializeErr)); |
||
| 571 | }, |
||
| 572 | Args...); |
||
| 573 | } |
||
| 574 | |||
| 575 | using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle; |
||
| 576 | using WrapperFunction<SPSEmpty(SPSTagTs...)>::handleAsync; |
||
| 577 | }; |
||
| 578 | |||
| 579 | /// A function object that takes an ExecutorAddr as its first argument, |
||
| 580 | /// casts that address to a ClassT*, then calls the given method on that |
||
| 581 | /// pointer passing in the remaining function arguments. This utility |
||
| 582 | /// removes some of the boilerplate from writing wrappers for method calls. |
||
| 583 | /// |
||
| 584 | /// @code{.cpp} |
||
| 585 | /// class MyClass { |
||
| 586 | /// public: |
||
| 587 | /// void myMethod(uint32_t, bool) { ... } |
||
| 588 | /// }; |
||
| 589 | /// |
||
| 590 | /// // SPS Method signature -- note MyClass object address as first argument. |
||
| 591 | /// using SPSMyMethodWrapperSignature = |
||
| 592 | /// SPSTuple<SPSExecutorAddr, uint32_t, bool>; |
||
| 593 | /// |
||
| 594 | /// WrapperFunctionResult |
||
| 595 | /// myMethodCallWrapper(const char *ArgData, size_t ArgSize) { |
||
| 596 | /// return WrapperFunction<SPSMyMethodWrapperSignature>::handle( |
||
| 597 | /// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod)); |
||
| 598 | /// } |
||
| 599 | /// @endcode |
||
| 600 | /// |
||
| 601 | template <typename RetT, typename ClassT, typename... ArgTs> |
||
| 602 | class MethodWrapperHandler { |
||
| 603 | public: |
||
| 604 | using MethodT = RetT (ClassT::*)(ArgTs...); |
||
| 605 | MethodWrapperHandler(MethodT M) : M(M) {} |
||
| 606 | RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) { |
||
| 607 | return (ObjAddr.toPtr<ClassT*>()->*M)(std::forward<ArgTs>(Args)...); |
||
| 608 | } |
||
| 609 | |||
| 610 | private: |
||
| 611 | MethodT M; |
||
| 612 | }; |
||
| 613 | |||
| 614 | /// Create a MethodWrapperHandler object from the given method pointer. |
||
| 615 | template <typename RetT, typename ClassT, typename... ArgTs> |
||
| 616 | MethodWrapperHandler<RetT, ClassT, ArgTs...> |
||
| 617 | makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) { |
||
| 618 | return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method); |
||
| 619 | } |
||
| 620 | |||
| 621 | /// Represents a serialized wrapper function call. |
||
| 622 | /// Serializing calls themselves allows us to batch them: We can make one |
||
| 623 | /// "run-wrapper-functions" utility and send it a list of calls to run. |
||
| 624 | /// |
||
| 625 | /// The motivating use-case for this API is JITLink allocation actions, where |
||
| 626 | /// we want to run multiple functions to finalize linked memory without having |
||
| 627 | /// to make separate IPC calls for each one. |
||
| 628 | class WrapperFunctionCall { |
||
| 629 | public: |
||
| 630 | using ArgDataBufferType = SmallVector<char, 24>; |
||
| 631 | |||
| 632 | /// Create a WrapperFunctionCall using the given SPS serializer to serialize |
||
| 633 | /// the arguments. |
||
| 634 | template <typename SPSSerializer, typename... ArgTs> |
||
| 635 | static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr, |
||
| 636 | const ArgTs &...Args) { |
||
| 637 | ArgDataBufferType ArgData; |
||
| 638 | ArgData.resize(SPSSerializer::size(Args...)); |
||
| 639 | SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(), |
||
| 640 | ArgData.size()); |
||
| 641 | if (SPSSerializer::serialize(OB, Args...)) |
||
| 642 | return WrapperFunctionCall(FnAddr, std::move(ArgData)); |
||
| 643 | return make_error<StringError>("Cannot serialize arguments for " |
||
| 644 | "AllocActionCall", |
||
| 645 | inconvertibleErrorCode()); |
||
| 646 | } |
||
| 647 | |||
| 648 | WrapperFunctionCall() = default; |
||
| 649 | |||
| 650 | /// Create a WrapperFunctionCall from a target function and arg buffer. |
||
| 651 | WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData) |
||
| 652 | : FnAddr(FnAddr), ArgData(std::move(ArgData)) {} |
||
| 653 | |||
| 654 | /// Returns the address to be called. |
||
| 655 | const ExecutorAddr &getCallee() const { return FnAddr; } |
||
| 656 | |||
| 657 | /// Returns the argument data. |
||
| 658 | const ArgDataBufferType &getArgData() const { return ArgData; } |
||
| 659 | |||
| 660 | /// WrapperFunctionCalls convert to true if the callee is non-null. |
||
| 661 | explicit operator bool() const { return !!FnAddr; } |
||
| 662 | |||
| 663 | /// Run call returning raw WrapperFunctionResult. |
||
| 664 | shared::WrapperFunctionResult run() const { |
||
| 665 | using FnTy = |
||
| 666 | shared::CWrapperFunctionResult(const char *ArgData, size_t ArgSize); |
||
| 667 | return shared::WrapperFunctionResult( |
||
| 668 | FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size())); |
||
| 669 | } |
||
| 670 | |||
| 671 | /// Run call and deserialize result using SPS. |
||
| 672 | template <typename SPSRetT, typename RetT> |
||
| 673 | std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error> |
||
| 674 | runWithSPSRet(RetT &RetVal) const { |
||
| 675 | auto WFR = run(); |
||
| 676 | if (const char *ErrMsg = WFR.getOutOfBandError()) |
||
| 677 | return make_error<StringError>(ErrMsg, inconvertibleErrorCode()); |
||
| 678 | shared::SPSInputBuffer IB(WFR.data(), WFR.size()); |
||
| 679 | if (!shared::SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal)) |
||
| 680 | return make_error<StringError>("Could not deserialize result from " |
||
| 681 | "serialized wrapper function call", |
||
| 682 | inconvertibleErrorCode()); |
||
| 683 | return Error::success(); |
||
| 684 | } |
||
| 685 | |||
| 686 | /// Overload for SPS functions returning void. |
||
| 687 | template <typename SPSRetT> |
||
| 688 | std::enable_if_t<std::is_same<SPSRetT, void>::value, Error> |
||
| 689 | runWithSPSRet() const { |
||
| 690 | shared::SPSEmpty E; |
||
| 691 | return runWithSPSRet<shared::SPSEmpty>(E); |
||
| 692 | } |
||
| 693 | |||
| 694 | /// Run call and deserialize an SPSError result. SPSError returns and |
||
| 695 | /// deserialization failures are merged into the returned error. |
||
| 696 | Error runWithSPSRetErrorMerged() const { |
||
| 697 | detail::SPSSerializableError RetErr; |
||
| 698 | if (auto Err = runWithSPSRet<SPSError>(RetErr)) |
||
| 699 | return Err; |
||
| 700 | return detail::fromSPSSerializable(std::move(RetErr)); |
||
| 701 | } |
||
| 702 | |||
| 703 | private: |
||
| 704 | orc::ExecutorAddr FnAddr; |
||
| 705 | ArgDataBufferType ArgData; |
||
| 706 | }; |
||
| 707 | |||
| 708 | using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>; |
||
| 709 | |||
| 710 | template <> |
||
| 711 | class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> { |
||
| 712 | public: |
||
| 713 | static size_t size(const WrapperFunctionCall &WFC) { |
||
| 714 | return SPSWrapperFunctionCall::AsArgList::size(WFC.getCallee(), |
||
| 715 | WFC.getArgData()); |
||
| 716 | } |
||
| 717 | |||
| 718 | static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) { |
||
| 719 | return SPSWrapperFunctionCall::AsArgList::serialize(OB, WFC.getCallee(), |
||
| 720 | WFC.getArgData()); |
||
| 721 | } |
||
| 722 | |||
| 723 | static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) { |
||
| 724 | ExecutorAddr FnAddr; |
||
| 725 | WrapperFunctionCall::ArgDataBufferType ArgData; |
||
| 726 | if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData)) |
||
| 727 | return false; |
||
| 728 | WFC = WrapperFunctionCall(FnAddr, std::move(ArgData)); |
||
| 729 | return true; |
||
| 730 | } |
||
| 731 | }; |
||
| 732 | |||
| 733 | } // end namespace shared |
||
| 734 | } // end namespace orc |
||
| 735 | } // end namespace llvm |
||
| 736 | |||
| 737 | #endif // LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H |