Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure exporter creates valid python name for opset vars #1051

Merged
merged 6 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 54 additions & 24 deletions onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
# --------------------------------------------------------------------------
from __future__ import annotations

from typing import Any, Optional
from typing import Any, Optional, Sequence

import numpy
import onnx
from onnx import FunctionProto, ModelProto, TensorProto, ValueInfoProto
from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, ValueInfoProto
from onnx.helper import make_node

import onnxscript.onnx_types
Expand All @@ -23,11 +23,9 @@
{% if unique_types %}
from onnxscript.onnx_types import {{ ", ".join(unique_types) }}
{%- endif %}
from onnxscript.onnx_opset import opset{{ opsets[''] }}
{% for domain, version in unique_function_domain_version: %}
{{ domain }}{{ version }} = Opset("{{ domain }}", {{ version }}){% endfor %}
{{translate_opset_imports_of(main_model)}}
{% for domain, name, fct in functions: %}
@script({{ domain }}1)
@script({{make_opset_name(domain, 1)}})
def {{ python_make_node_name(fct['proto'].domain, 1, fct['proto'].name) }}{{
translate_function_signature(fct['proto'])}}
{% if fct['proto'].doc_string %}"""
Expand Down Expand Up @@ -201,22 +199,6 @@
raise NotImplementedError(f"Unable to return a value for attribute {attr!r}.")


def _python_make_node_name(domain, version, name, node=False):
name = _rename_variable(name)
if node:
if version is None:
version = 1
if not isinstance(version, int):
raise TypeError(
f"version must be an integer not {version!r} for domain={domain!r} "
f"and name={name!r}."
)
if domain == "":
return f"opset{version}.{name}"
return f"{domain.replace('.', '_')}{version}.{name}"
return name


class Exporter:
"""Class used for recursive traversal of Proto structures."""

Expand All @@ -230,6 +212,28 @@
"""Renames all names equal to a python keyword."""
return str(self._rename_variable(name))

def _rename_domain(self, domain: str) -> str:
if domain == "":
return "opset"
return domain.replace(".", "_")
gramalingam marked this conversation as resolved.
Show resolved Hide resolved

def make_opset_name(self, domain, version):
return f"{self._rename_domain(domain)}{version}"
gramalingam marked this conversation as resolved.
Show resolved Hide resolved

def _python_make_node_name(self, domain, version, name, node=False):
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
name = _rename_variable(name)
if node:
if version is None:
version = 1

Check warning on line 227 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L227

Added line #L227 was not covered by tests
if not isinstance(version, int):
raise TypeError(

Check warning on line 229 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L229

Added line #L229 was not covered by tests
f"version must be an integer not {version!r} for domain={domain!r} "
f"and name={name!r}."
)
opset = self.make_opset_name(domain, version)
return f"{opset}.{name}"
return name

def _python_make_node_graph(self, graph, opsets, indent=0, output_names=None):
"""Translates a GraphProto into python."""
code = []
Expand Down Expand Up @@ -403,7 +407,7 @@
f"{sindent}{self._rename_variable(node.output[0])} = "
f"{(f' {ops[node.op_type]} ').join(map(self.lookup, node.input))}"
)
name = _python_make_node_name(
name = self._python_make_node_name(
node.domain, opsets[node.domain], node.op_type, node=True
)
attributes_str = self._python_make_node_make_attribute_str(node)
Expand All @@ -428,6 +432,29 @@
]
return "".join(text)

def translate_opset_import(self, domain: str, version: int) -> str:
if domain in {"", "ai.onnx"}:
return f"from onnxscript.onnx_opset import opset{version}\n"
else:
varname = self.make_opset_name(domain, version)
return f"{varname} = Opset('{domain}', {version})\n"

def translate_opset_imports(self, opset_imports: Sequence[onnx.OperatorSetIdProto]) -> str:
return "".join(
[self.translate_opset_import(x.domain, x.version) for x in opset_imports]
)

def translate_opset_imports_of(
self, proto: ModelProto | FunctionProto | GraphProto
) -> str:
if hasattr(proto, "opset_import"):
text = self.translate_opset_imports(proto.opset_import)
if isinstance(proto, FunctionProto):
if not any(x.domain == proto.domain for x in proto.opset_import):
text += self.translate_opset_import(proto.domain, 1)
return text
return ""

Check warning on line 456 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L456

Added line #L456 was not covered by tests

def translate_function_signature(self, funproto: onnx.FunctionProto) -> str:
"""Generate signature for FunctionProto."""
type_map = _attribute_param_types(funproto)
Expand Down Expand Up @@ -522,10 +549,13 @@
"main_model": model_onnx,
"python_make_node": exporter._python_make_node, # pylint: disable=protected-access
"python_make_node_graph": exporter._python_make_node_graph, # pylint: disable=protected-access
"python_make_node_name": _python_make_node_name, # pylint: disable=protected-access
"python_make_node_name": exporter._python_make_node_name, # pylint: disable=protected-access
"rename": rename_variable,
"translate_sig": _translate_signature,
"translate_function_signature": exporter.translate_function_signature,
"translate_opset_imports_of": exporter.translate_opset_imports_of,
"hasattr": hasattr,
"make_opset_name": exporter.make_opset_name,
}

# opset
Expand Down
18 changes: 15 additions & 3 deletions onnxscript/backend/onnx_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import onnxscript
import onnxscript.testing
import onnxscript.values
from onnxscript.backend import onnx_backend, onnx_export
from onnxscript.tests.models import type_double

Expand Down Expand Up @@ -136,7 +137,8 @@
test_folder = root_folder / "tests" / "onnx_backend_test_code"
temp_folder = root_folder / "tests" / "export"

def _round_trip_check(self, proto, **export_options):
def _round_trip_check(self, script_function, **export_options):
proto = script_function.to_function_proto()
code = onnx_export.export2python(proto, **export_options)
map = extract_functions(proto.name, code, TestOnnxBackEnd.temp_folder)
result_proto = map[proto.name]
Expand All @@ -150,8 +152,18 @@
def fun_with_attr_param(X, dtype: int):
return op.Cast(X, to=dtype)

fun_proto = fun_with_attr_param.to_function_proto()
self._round_trip_check(fun_proto)
self._round_trip_check(fun_with_attr_param)

def test_qualified_domain(self):
"""Test use of qualified domain name."""
op = onnxscript.opset17
custom_opset = onnxscript.values.Opset("my.domain.com", 1)

@onnxscript.script(custom_opset)
def twice(X):
return op.Add(X, X)

Check warning on line 164 in onnxscript/backend/onnx_export_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export_test.py#L164

Added line #L164 was not covered by tests

self._round_trip_check(twice)

def test_export2python(self):
proto = type_double.double_abs_subgraph.to_model_proto()
Expand Down
Loading