Skip to content

Commit

Permalink
refactor!: pkg_model -> pkg_asset
Browse files Browse the repository at this point in the history
  • Loading branch information
homuler committed Feb 21, 2021
1 parent 9704a33 commit 20083ff
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 25 deletions.
4 changes: 2 additions & 2 deletions C/mediapipe_api/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
load("@rules_pkg//:pkg.bzl", "pkg_zip")
load("//mediapipe_api:import_model.bzl", "pkg_model")
load("//mediapipe_api:import_model.bzl", "pkg_asset")

cc_library(
name = "mediapipe_c",
Expand Down Expand Up @@ -129,7 +129,7 @@ pkg_zip(
# TODO: keep directory structure
)

pkg_model(
pkg_asset(
name = "mediapipe_models",
srcs = [
"@com_google_mediapipe//mediapipe/models:face_detection_front.tflite",
Expand Down
57 changes: 34 additions & 23 deletions C/mediapipe_api/import_model.bzl
Original file line number Diff line number Diff line change
@@ -1,37 +1,48 @@
load("@rules_pkg//:pkg.bzl", "pkg_zip")

def pkg_model(name, srcs = [], **kwargs):
"""Package MediaPipe models
This task renames model files so that they can be added to an AssetBundle (e.g. x.tflte -> x.bytes) and zip them.
def pkg_asset(name, srcs = [], **kwargs):
"""Package MediaPipe assets
This task renames asset files so that they can be added to an AssetBundle (e.g. x.tflte -> x.bytes) and zip them.
Args:
name: the name of the output zip file
srcs: files to be packaged
"""

for src in srcs:
_rename_file(src)
rename_target = "normalize_%s_exts" % name
_normalize_exts(name = rename_target, srcs = srcs)

pkg_zip(
name = name,
srcs = [_export_file_path(src) for src in srcs],
srcs = [":" + rename_target],
**kwargs,
)

def _rename_file(src):
export_file = _export_file_path(src)

native.genrule(
name = "export_" + export_file,
srcs = [src],
outs = [export_file],
cmd = "cp $< $@",
)

def _export_file_path(src):
[prefix, base_name] = src.split(":") # src must contain one colon
name_arr = base_name.split(".")
[name, ext] = [base_name, ""] if len(name_arr) == 1 else ["".join(name_arr[:-1]), name_arr[-1]]
export_file_ext = "bytes" if ext == "tflite" else "txt"

return "{}.{}".format(name, export_file_ext)
def _normalize_exts_impl(ctx):
output_files = []

for src in ctx.files.srcs:
if src.extension in ctx.attr.bytes_exts:
dest = ctx.actions.declare_file(src.path[:-1 * len(src.extension)] + "bytes")
ctx.actions.run_shell(
inputs = [src],
outputs = [dest],
arguments = [src.path, dest.path],
command = "test $1 != $2 && cp $1 $2",
progress_message = "Copying %s to %s...".format(src.path, dest.path),
)
output_files.append(dest)
else:
output_files.append(src)

return [
DefaultInfo(files = depset(output_files)),
]

_normalize_exts = rule(
implementation = _normalize_exts_impl,
attrs = {
"srcs": attr.label_list(allow_files = True),
"bytes_exts": attr.string_list(default = ["jpg", "tflite", "uuu"]),
},
)

0 comments on commit 20083ff

Please sign in to comment.