-
-
Notifications
You must be signed in to change notification settings - Fork 450
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
36 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,48 @@ | ||
load("@rules_pkg//:pkg.bzl", "pkg_zip") | ||
|
||
def pkg_model(name, srcs = [], **kwargs): | ||
"""Package MediaPipe models | ||
This task renames model files so that they can be added to an AssetBundle (e.g. x.tflte -> x.bytes) and zip them. | ||
def pkg_asset(name, srcs = [], **kwargs): | ||
"""Package MediaPipe assets | ||
This task renames asset files so that they can be added to an AssetBundle (e.g. x.tflte -> x.bytes) and zip them. | ||
Args: | ||
name: the name of the output zip file | ||
srcs: files to be packaged | ||
""" | ||
|
||
for src in srcs: | ||
_rename_file(src) | ||
rename_target = "normalize_%s_exts" % name | ||
_normalize_exts(name = rename_target, srcs = srcs) | ||
|
||
pkg_zip( | ||
name = name, | ||
srcs = [_export_file_path(src) for src in srcs], | ||
srcs = [":" + rename_target], | ||
**kwargs, | ||
) | ||
|
||
def _rename_file(src): | ||
export_file = _export_file_path(src) | ||
|
||
native.genrule( | ||
name = "export_" + export_file, | ||
srcs = [src], | ||
outs = [export_file], | ||
cmd = "cp $< $@", | ||
) | ||
|
||
def _export_file_path(src): | ||
[prefix, base_name] = src.split(":") # src must contain one colon | ||
name_arr = base_name.split(".") | ||
[name, ext] = [base_name, ""] if len(name_arr) == 1 else ["".join(name_arr[:-1]), name_arr[-1]] | ||
export_file_ext = "bytes" if ext == "tflite" else "txt" | ||
|
||
return "{}.{}".format(name, export_file_ext) | ||
def _normalize_exts_impl(ctx): | ||
output_files = [] | ||
|
||
for src in ctx.files.srcs: | ||
if src.extension in ctx.attr.bytes_exts: | ||
dest = ctx.actions.declare_file(src.path[:-1 * len(src.extension)] + "bytes") | ||
ctx.actions.run_shell( | ||
inputs = [src], | ||
outputs = [dest], | ||
arguments = [src.path, dest.path], | ||
command = "test $1 != $2 && cp $1 $2", | ||
progress_message = "Copying %s to %s...".format(src.path, dest.path), | ||
) | ||
output_files.append(dest) | ||
else: | ||
output_files.append(src) | ||
|
||
return [ | ||
DefaultInfo(files = depset(output_files)), | ||
] | ||
|
||
_normalize_exts = rule( | ||
implementation = _normalize_exts_impl, | ||
attrs = { | ||
"srcs": attr.label_list(allow_files = True), | ||
"bytes_exts": attr.string_list(default = ["jpg", "tflite", "uuu"]), | ||
}, | ||
) |