Skip to content

Commit

Permalink
Ensure exporter creates valid python name for opset vars (#1051)
Browse files Browse the repository at this point in the history
The existing exporter creates invalid python identifiers like
"onnxscript.atenlib1" when generating variable names for opsets. Change
this to create a name like "onnxscript_atenlib1".

---------

Signed-off-by: Ganesan Ramalingam <[email protected]>
  • Loading branch information
gramalingam committed Sep 7, 2023
1 parent 8af8161 commit 5eebc89
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 27 deletions.
78 changes: 54 additions & 24 deletions onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
# --------------------------------------------------------------------------
from __future__ import annotations

from typing import Any, Optional
from typing import Any, Optional, Sequence

import numpy
import onnx
from onnx import FunctionProto, ModelProto, TensorProto, ValueInfoProto
from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, ValueInfoProto
from onnx.helper import make_node

import onnxscript.onnx_types
Expand All @@ -23,11 +23,9 @@
{% if unique_types %}
from onnxscript.onnx_types import {{ ", ".join(unique_types) }}
{%- endif %}
from onnxscript.onnx_opset import opset{{ opsets[''] }}
{% for domain, version in unique_function_domain_version: %}
{{ domain }}{{ version }} = Opset("{{ domain }}", {{ version }}){% endfor %}
{{translate_opset_imports_of(main_model)}}
{% for domain, name, fct in functions: %}
@script({{ domain }}1)
@script({{make_opset_name(domain, 1)}})
def {{ python_make_node_name(fct['proto'].domain, 1, fct['proto'].name) }}{{
translate_function_signature(fct['proto'])}}
{% if fct['proto'].doc_string %}"""
Expand Down Expand Up @@ -201,22 +199,6 @@ def _attribute_value(attr: onnx.AttributeProto):
raise NotImplementedError(f"Unable to return a value for attribute {attr!r}.")


def _python_make_node_name(domain, version, name, node=False):
name = _rename_variable(name)
if node:
if version is None:
version = 1
if not isinstance(version, int):
raise TypeError(
f"version must be an integer not {version!r} for domain={domain!r} "
f"and name={name!r}."
)
if domain == "":
return f"opset{version}.{name}"
return f"{domain.replace('.', '_')}{version}.{name}"
return name


class Exporter:
"""Class used for recursive traversal of Proto structures."""

Expand All @@ -230,6 +212,28 @@ def _rename_variable_s(self, name):
"""Renames all names equal to a python keyword."""
return str(self._rename_variable(name))

def _rename_domain(self, domain: str) -> str:
if domain == "":
return "opset"
return domain.replace(".", "_")

def make_opset_name(self, domain, version):
return f"{self._rename_domain(domain)}{version}"

def _python_make_node_name(self, domain, version, name, node=False):
name = _rename_variable(name)
if node:
if version is None:
version = 1
if not isinstance(version, int):
raise TypeError(
f"version must be an integer not {version!r} for domain={domain!r} "
f"and name={name!r}."
)
opset = self.make_opset_name(domain, version)
return f"{opset}.{name}"
return name

def _python_make_node_graph(self, graph, opsets, indent=0, output_names=None):
"""Translates a GraphProto into python."""
code = []
Expand Down Expand Up @@ -403,7 +407,7 @@ def _python_make_node(self, onnx_node, opsets, indent=0):
f"{sindent}{self._rename_variable(node.output[0])} = "
f"{(f' {ops[node.op_type]} ').join(map(self.lookup, node.input))}"
)
name = _python_make_node_name(
name = self._python_make_node_name(
node.domain, opsets[node.domain], node.op_type, node=True
)
attributes_str = self._python_make_node_make_attribute_str(node)
Expand All @@ -428,6 +432,29 @@ def _python_make_node(self, onnx_node, opsets, indent=0):
]
return "".join(text)

def translate_opset_import(self, domain: str, version: int) -> str:
if domain in {"", "ai.onnx"}:
return f"from onnxscript.onnx_opset import opset{version}\n"
else:
varname = self.make_opset_name(domain, version)
return f"{varname} = Opset('{domain}', {version})\n"

def translate_opset_imports(self, opset_imports: Sequence[onnx.OperatorSetIdProto]) -> str:
return "".join(
[self.translate_opset_import(x.domain, x.version) for x in opset_imports]
)

def translate_opset_imports_of(
self, proto: ModelProto | FunctionProto | GraphProto
) -> str:
if hasattr(proto, "opset_import"):
text = self.translate_opset_imports(proto.opset_import)
if isinstance(proto, FunctionProto):
if not any(x.domain == proto.domain for x in proto.opset_import):
text += self.translate_opset_import(proto.domain, 1)
return text
return ""

def translate_function_signature(self, funproto: onnx.FunctionProto) -> str:
"""Generate signature for FunctionProto."""
type_map = _attribute_param_types(funproto)
Expand Down Expand Up @@ -522,10 +549,13 @@ def rename_variable(name):
"main_model": model_onnx,
"python_make_node": exporter._python_make_node, # pylint: disable=protected-access
"python_make_node_graph": exporter._python_make_node_graph, # pylint: disable=protected-access
"python_make_node_name": _python_make_node_name, # pylint: disable=protected-access
"python_make_node_name": exporter._python_make_node_name, # pylint: disable=protected-access
"rename": rename_variable,
"translate_sig": _translate_signature,
"translate_function_signature": exporter.translate_function_signature,
"translate_opset_imports_of": exporter.translate_opset_imports_of,
"hasattr": hasattr,
"make_opset_name": exporter.make_opset_name,
}

# opset
Expand Down
18 changes: 15 additions & 3 deletions onnxscript/backend/onnx_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import onnxscript
import onnxscript.testing
import onnxscript.values
from onnxscript.backend import onnx_backend, onnx_export
from onnxscript.tests.models import type_double

Expand Down Expand Up @@ -136,7 +137,8 @@ class TestOnnxBackEnd(unittest.TestCase):
test_folder = root_folder / "tests" / "onnx_backend_test_code"
temp_folder = root_folder / "tests" / "export"

def _round_trip_check(self, proto, **export_options):
def _round_trip_check(self, script_function, **export_options):
proto = script_function.to_function_proto()
code = onnx_export.export2python(proto, **export_options)
map = extract_functions(proto.name, code, TestOnnxBackEnd.temp_folder)
result_proto = map[proto.name]
Expand All @@ -150,8 +152,18 @@ def test_attr_ref(self):
def fun_with_attr_param(X, dtype: int):
return op.Cast(X, to=dtype)

fun_proto = fun_with_attr_param.to_function_proto()
self._round_trip_check(fun_proto)
self._round_trip_check(fun_with_attr_param)

def test_qualified_domain(self):
"""Test use of qualified domain name."""
op = onnxscript.opset17
custom_opset = onnxscript.values.Opset("my.domain.com", 1)

@onnxscript.script(custom_opset)
def twice(X):
return op.Add(X, X)

self._round_trip_check(twice)

def test_export2python(self):
proto = type_double.double_abs_subgraph.to_model_proto()
Expand Down

0 comments on commit 5eebc89

Please sign in to comment.