Details | Last modification | View Log | RSS feed
Rev | Author | Line No. | Line |
---|---|---|---|
# Resolve ${path} against ${base} and store the result in the caller variable
# named by ${final_path}. Absolute paths are returned unchanged; relative paths
# are joined as ${base}/${path} (no normalization is performed).
function(tf_get_absolute_path path base final_path)
  set(resolved ${path})
  if (NOT IS_ABSOLUTE ${path})
    set(resolved ${base}/${path})
  endif()
  set(${final_path} ${resolved} PARENT_SCOPE)
endfunction()
||
8 | |||
# Resolve ${model} to a local SavedModel directory and store it in the caller
# variable named by ${final_path}.
#
#   model      - either an http:// / https:// URL of a model archive, or a
#                filesystem path (relative paths are resolved against
#                ${CMAKE_CURRENT_BINARY_DIR}).
#   final_path - name of the caller variable that receives the resolved path.
#
# For a URL, the archive is downloaded at configure time into
# ${CMAKE_CURRENT_BINARY_DIR}/<archive name>, extracted into
# ${CMAKE_CURRENT_BINARY_DIR}/<archive name>_model, and the returned path is
# the "model" subdirectory of that extraction. A failed download stops
# configuration with FATAL_ERROR.
function(tf_get_model model final_path)
  if ("${model}" MATCHES "^https?:")
    message("Downloading model " ${model})
    # The archive name is everything after the last '/' in the URL.
    string(FIND "${model}" "/" fname_start REVERSE)
    math(EXPR fname_start "${fname_start}+1")
    string(SUBSTRING "${model}" ${fname_start} -1 fname)
    message("Model archive: " ${fname})
    set(archive ${CMAKE_CURRENT_BINARY_DIR}/${fname})
    file(DOWNLOAD ${model} ${archive} STATUS download_status)
    # STATUS is the list "<numeric code>;<message>"; 0 means success. Checking
    # it here avoids handing a missing/partial file to ARCHIVE_EXTRACT.
    list(GET download_status 0 download_code)
    if (NOT download_code EQUAL 0)
      list(GET download_status 1 download_error)
      message(FATAL_ERROR "Failed to download model ${model}: ${download_error}")
    endif()
    file(ARCHIVE_EXTRACT INPUT
      ${archive}
      DESTINATION ${archive}_model)
    set(${final_path} ${archive}_model/model PARENT_SCOPE)
  else()
    tf_get_absolute_path(${model} ${CMAKE_CURRENT_BINARY_DIR} model_path)
    set(${final_path} ${model_path} PARENT_SCOPE)
  endif()
endfunction()
||
28 | |||
# Generate a mock model for tests.
#
#   generator - Python script that produces the model; relative paths are
#               resolved against ${CMAKE_CURRENT_SOURCE_DIR}.
#   output    - destination passed to the generator; relative paths are
#               resolved against ${CMAKE_CURRENT_BINARY_DIR}.
#
# Runs `${Python3_EXECUTABLE} <generator> <output>` at configure time from
# ${CMAKE_CURRENT_SOURCE_DIR}. If the generator does not exit successfully,
# configuration stops with FATAL_ERROR. Otherwise the missing model would only
# be noticed later, when something tries to consume it.
function(generate_mock_model generator output)
  tf_get_absolute_path(${generator} ${CMAKE_CURRENT_SOURCE_DIR} generator_absolute_path)
  tf_get_absolute_path(${output} ${CMAKE_CURRENT_BINARY_DIR} output_absolute_path)
  message(WARNING "Autogenerated mock models should not be used in production builds.")
  execute_process(COMMAND ${Python3_EXECUTABLE}
    ${generator_absolute_path}
    ${output_absolute_path}
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
    RESULT_VARIABLE generator_result
  )
  # RESULT_VARIABLE is the exit code, or an error string if the process could
  # not be started; either way anything other than 0 is a failure.
  if (NOT generator_result EQUAL 0)
    message(FATAL_ERROR "Mock model generator ${generator_absolute_path} failed "
      "(${generator_result}) while producing ${output_absolute_path}")
  endif()
endfunction()
||
40 | |||
# Run the tensorflow compiler (saved_model_cli) on the saved model in the
# ${model} directory, looking for the ${tag_set} tag set, and the SignatureDef
# ${signature_def_key}.
# Produce a pair of files called ${fname}.h and ${fname}.o in the
# ${CMAKE_CURRENT_BINARY_DIR}. The generated header will define a C++ class
# called ${cpp_class} - which may be a namespace-qualified class name.
#
#   model             - SavedModel directory; relative paths are resolved
#                       against ${CMAKE_CURRENT_BINARY_DIR}.
#   hdr_file/obj_file - paths declared as the outputs of the build rule; they
#                       are expected to be ${CMAKE_CURRENT_BINARY_DIR}/${fname}.h
#                       and ${CMAKE_CURRENT_BINARY_DIR}/${fname}.o.
#
# Side effects: adds a custom command producing hdr_file and obj_file, marks
# both as generated, and appends obj_file to GENERATED_OBJS and hdr_file to
# GENERATED_HEADERS in the caller's scope.
function(tf_compile model tag_set signature_def_key fname cpp_class hdr_file obj_file)
  tf_get_absolute_path(${model} ${CMAKE_CURRENT_BINARY_DIR} LLVM_ML_MODELS_ABSOLUTE)
  message("Using model at " ${LLVM_ML_MODELS_ABSOLUTE})
  # The output prefix is derived from ${fname}, as documented above. Previously
  # it was read from a `prefix` variable that this function never set and that
  # only existed when the caller happened to define it.
  set(prefix ${CMAKE_CURRENT_BINARY_DIR}/${fname})
  add_custom_command(OUTPUT ${obj_file} ${hdr_file}
    COMMAND ${TENSORFLOW_AOT_COMPILER} aot_compile_cpu
      --multithreading false
      --dir ${LLVM_ML_MODELS_ABSOLUTE}
      --tag_set ${tag_set}
      --signature_def_key ${signature_def_key}
      --output_prefix ${prefix}
      --cpp_class ${cpp_class}
      --target_triple ${LLVM_HOST_TRIPLE}
    VERBATIM
  )

  # Aggregate the objects so that results of different tf_compile calls may be
  # grouped into one target.
  set(GENERATED_OBJS ${GENERATED_OBJS} ${obj_file} PARENT_SCOPE)
  set_source_files_properties(${obj_file} PROPERTIES
    GENERATED 1 EXTERNAL_OBJECT 1)

  set(GENERATED_HEADERS ${GENERATED_HEADERS} ${hdr_file} PARENT_SCOPE)
  set_source_files_properties(${hdr_file} PROPERTIES
    GENERATED 1)
endfunction()
||
72 | |||
# Make the AOT-compiled model ${fname} available to the build.
#
#   model                - "none" to skip, "download" to fetch ${default_url},
#                          "autogenerate" to build a mock model with
#                          ${test_model_generator}, or a URL / path to a
#                          SavedModel (resolved by tf_get_model).
#   default_url          - URL used when model is "download"; the sentinel
#                          "<UNSPECIFIED>" means no URL is available.
#   default_path         - base name of the autogenerated mock model, which is
#                          written to ${default_path}-autogenerated.
#   test_model_generator - Python script that generates the mock model.
#   tag_set, signature_def_key, cpp_class - forwarded to tf_compile.
#   fname                - base name of the generated ${fname}.h / ${fname}.o in
#                          ${CMAKE_CURRENT_BINARY_DIR}.
#
# If LLVM_OVERRIDE_MODEL_HEADER_<FNAME> and LLVM_OVERRIDE_MODEL_OBJECT_<FNAME>
# (FNAME = upper-cased fname) both name existing files, those files are copied
# into place instead of compiling a model.
#
# Unless model is "none", appends the generated objects/headers to
# GeneratedMLSources and tf_xla_runtime to MLDeps and MLLinkDeps in the caller's
# scope, and adds the directory-wide compile definition
# LLVM_HAVE_TF_AOT_<FNAME>.
function(tf_find_and_compile model default_url default_path test_model_generator tag_set signature_def_key fname cpp_class)
  set(prefix ${CMAKE_CURRENT_BINARY_DIR}/${fname})
  set(obj_file ${prefix}.o)
  set(hdr_file ${prefix}.h)
  string(TOUPPER ${fname} fname_allcaps)
  set(override_header ${LLVM_OVERRIDE_MODEL_HEADER_${fname_allcaps}})
  set(override_object ${LLVM_OVERRIDE_MODEL_OBJECT_${fname_allcaps}})
  # If the user specified overrides, that indicates intent to use AOT and we
  # don't care what the model path is
  if (EXISTS "${override_header}" AND EXISTS "${override_object}")
    configure_file(${override_header} ${hdr_file} COPYONLY)
    configure_file(${override_object} ${obj_file} COPYONLY)
    # Keep the message on one line: a string literal spanning a source line
    # break would embed the newline and indentation in the output.
    message(STATUS "Using provided header " ${hdr_file} " and object " ${obj_file}
      " files for model " ${fname})
    list(APPEND GENERATED_OBJS ${obj_file})
    list(APPEND GENERATED_HEADERS ${hdr_file})
  elseif("${model}" STREQUAL "none")
    message(STATUS "Will skip enabling mlgo for ${fname}")
    return()
  else()
    if ("${model}" STREQUAL "download")
      # Fail if the user wants to download a model but no URL is configured
      # (default_url is the "<UNSPECIFIED>" sentinel).
      if ("${default_url}" STREQUAL "<UNSPECIFIED>")
        message(FATAL_ERROR "Model path was set to 'download' but there is no"
          " model url currently specified in cmake. You can generate a model"
          " using, for example, the tools at http://github.com/google/ml-compiler-opt."
          " Some reference models are also periodically released there.")
      endif()

      set(model ${default_url})
    endif()

    if ("${model}" STREQUAL "autogenerate")
      set(model ${default_path}-autogenerated)
      generate_mock_model(${test_model_generator} ${model})
    endif()

    # tf_compile appends to GENERATED_OBJS / GENERATED_HEADERS in this scope.
    tf_get_model(${model} LLVM_ML_MODELS_ABSOLUTE)
    tf_compile(${LLVM_ML_MODELS_ABSOLUTE} ${tag_set} ${signature_def_key} ${fname} ${cpp_class} ${hdr_file} ${obj_file})
  endif()

  set(GeneratedMLSources ${GeneratedMLSources} ${GENERATED_OBJS} ${GENERATED_HEADERS} PARENT_SCOPE)
  set(MLDeps ${MLDeps} tf_xla_runtime PARENT_SCOPE)
  set(MLLinkDeps ${MLLinkDeps} tf_xla_runtime PARENT_SCOPE)
  add_compile_definitions(LLVM_HAVE_TF_AOT_${fname_allcaps})
endfunction()