Skip to content

Commit

Permalink
Implement a ONNX to ONNX Script code generator based on libcst
Browse files Browse the repository at this point in the history
- Adds some general codegen utilities based on libcst

- Implements an ONNX to ONNX Script generator: the base converter
  produces ONNX Script that is very 1:1 with the structure of ONNX,
  and transformers are implemented to raise the generated code to
  more idiomatic Python that ONNX Script supports; this commit
  provides support for raising to Python binary operators and raising
  Constant/make_tensor to supported Python constants; more transformers
  need to be implemented, but this commit can be used as a guide.

- Adds a new top-level command line interface, allowing the code
  generator to be invoked:

  python -m onnxscript convert model.onnx
  • Loading branch information
abock committed Jul 13, 2023
1 parent 97604f6 commit 99a2f51
Show file tree
Hide file tree
Showing 3 changed files with 808 additions and 0 deletions.
60 changes: 60 additions & 0 deletions onnxscript/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from __future__ import annotations

Check warning on line 6 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L6

Added line #L6 was not covered by tests

import argparse
from pathlib import Path
from typing import BinaryIO, Protocol

Check warning on line 10 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L8-L10

Added lines #L8 - L10 were not covered by tests

from onnxscript.codeanalysis import onnx_to_onnxscript

Check warning on line 12 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L12

Added line #L12 was not covered by tests


class ConvertCommandArgs(Protocol):
onnx_model_reader: BinaryIO
onnxscript_writer: BinaryIO

Check warning on line 17 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L15-L17

Added lines #L15 - L17 were not covered by tests


def convert_command(args: ConvertCommandArgs):
args.onnxscript_writer.write(

Check warning on line 21 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L20-L21

Added lines #L20 - L21 were not covered by tests
onnx_to_onnxscript.Driver(args.onnx_model_reader).to_python_code(
None
if args.onnxscript_writer.name == "<stdout>"
else Path(args.onnxscript_writer.name)
)
)


def main():
parser = argparse.ArgumentParser(prog="onnxscript")
subparsers = parser.add_subparsers(required=True)

Check warning on line 32 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L30-L32

Added lines #L30 - L32 were not covered by tests

parser_convert = subparsers.add_parser(

Check warning on line 34 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L34

Added line #L34 was not covered by tests
"convert",
help="Convert an ONNX model to ONNX Script Python code",
description="Convert an ONNX model to ONNX Script Python code",
)
parser_convert.set_defaults(func=convert_command)
parser_convert.add_argument(

Check warning on line 40 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L39-L40

Added lines #L39 - L40 were not covered by tests
"onnx_model_reader",
metavar="ONNX_MODEL_FILE",
type=argparse.FileType("rb"),
)
parser_convert.add_argument(

Check warning on line 45 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L45

Added line #L45 was not covered by tests
"--output",
dest="onnxscript_writer",
metavar="OUTPUT_FILE",
type=argparse.FileType("wb"),
help="file path for writing generated ONNX Script code",
default="-",
required=False,
)

args = parser.parse_args()
args.func(args)

Check warning on line 56 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L55-L56

Added lines #L55 - L56 were not covered by tests


if __name__ == "__main__":
main()

Check warning on line 60 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L60

Added line #L60 was not covered by tests
215 changes: 215 additions & 0 deletions onnxscript/codeanalysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-ancestors
# --------------------------------------------------------------------------

from __future__ import annotations

Check warning on line 9 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L9

Added line #L9 was not covered by tests

import os
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Final, Protocol, Sequence, runtime_checkable

Check warning on line 15 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L11-L15

Added lines #L11 - L15 were not covered by tests

import libcst as cst
import libcst.matchers as cstm
import libcst.metadata as cstmeta

Check warning on line 19 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L17-L19

Added lines #L17 - L19 were not covered by tests

__all__ = [

Check warning on line 21 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L21

Added line #L21 was not covered by tests
"format_code",
"make_name",
"make_import_alias",
"make_const_expr",
"RemoveUnusedImportsTransformer",
"CstCodeGenerator",
]


def format_code(path: Path | None, code: bytes) -> bytes:
try:
import ufmt

Check warning on line 33 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L31-L33

Added lines #L31 - L33 were not covered by tests

if path is None:
path = Path(os.curdir)

Check warning on line 36 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L36

Added line #L36 was not covered by tests

return ufmt.ufmt_bytes(

Check warning on line 38 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L38

Added line #L38 was not covered by tests
path,
code,
black_config=ufmt.util.make_black_config(path),
usort_config=ufmt.UsortConfig.find(path),
)
except ImportError:
return code

Check warning on line 45 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L44-L45

Added lines #L44 - L45 were not covered by tests


def make_name(name: str) -> cst.Attribute | cst.Name:
tokens = name.split(".")
expr: cst.Name | cst.Attribute = cst.Name(tokens[0])

Check warning on line 50 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L48-L50

Added lines #L48 - L50 were not covered by tests
for attr in tokens[1:]:
expr = cst.Attribute(expr, cst.Name(attr))
return expr

Check warning on line 53 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L52-L53

Added lines #L52 - L53 were not covered by tests


def make_import_alias(name: str, asname: str | None = None) -> cst.ImportAlias:
return cst.ImportAlias(

Check warning on line 57 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L56-L57

Added lines #L56 - L57 were not covered by tests
name=make_name(name),
asname=cst.AsName(cst.Name(asname)) if asname else None,
)


def make_const_expr(const: str | int | float) -> cst.BaseExpression:
negate = False

Check warning on line 64 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L63-L64

Added lines #L63 - L64 were not covered by tests
val: cst.Float | cst.Integer

if isinstance(const, str):
return cst.SimpleString('"' + const.replace('"', '\\"') + '"')

Check warning on line 68 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L68

Added line #L68 was not covered by tests
elif isinstance(const, int):
val = cst.Integer(str(abs(const)))
negate = const < 0

Check warning on line 71 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L70-L71

Added lines #L70 - L71 were not covered by tests
elif isinstance(const, float):
val = cst.Float(str(abs(const)))
negate = const < 0

Check warning on line 74 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L73-L74

Added lines #L73 - L74 were not covered by tests
else:
raise NotImplementedError(repr(const))

Check warning on line 76 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L76

Added line #L76 was not covered by tests

if negate:
return cst.UnaryOperation(

Check warning on line 79 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L79

Added line #L79 was not covered by tests
operator=cst.Minus(),
expression=val,
)

return val

Check warning on line 84 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L84

Added line #L84 was not covered by tests


@dataclass
class ImportAlias:
name: str
alias: str | None = None

Check warning on line 90 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L88-L90

Added lines #L88 - L90 were not covered by tests

def to_cst(self) -> cst.ImportAlias:
return cst.ImportAlias(

Check warning on line 93 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L92-L93

Added lines #L92 - L93 were not covered by tests
make_name(self.name), cst.AsName(cst.Name(self.alias)) if self.alias else None
)


@dataclass
class Import:
module: ImportAlias

Check warning on line 100 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L99-L100

Added lines #L99 - L100 were not covered by tests

def to_cst(self) -> cst.Import:
return cst.Import(names=[self.module.to_cst()])

Check warning on line 103 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L102-L103

Added lines #L102 - L103 were not covered by tests


@dataclass
class ImportFrom:
module: str
names: list[ImportAlias]

Check warning on line 109 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L107-L109

Added lines #L107 - L109 were not covered by tests

def to_cst(self) -> cst.ImportFrom:

Check warning on line 111 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L111

Added line #L111 was not covered by tests
return cst.ImportFrom(
module=make_name(self.module),
names=[name.to_cst() for name in self.names],
)


@runtime_checkable
class ScopeAnalyzer(Protocol):
def analyze_scopes(self, scopes: set[cstmeta.Scope]):
pass

Check warning on line 121 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L119-L121

Added lines #L119 - L121 were not covered by tests


class RemoveUnusedImportsTransformer(cst.CSTTransformer, ScopeAnalyzer):
def __init__(self):
self.__unused_imports: dict[cst.Import | cst.ImportFrom, set[str]] = defaultdict(set)

Check warning on line 126 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L124-L126

Added lines #L124 - L126 were not covered by tests

def is_unused_allowed(self, node: cst.Import | cst.ImportFrom, name: str):
return name == "annotations" and cstm.matches(

Check warning on line 129 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L128-L129

Added lines #L128 - L129 were not covered by tests
node, cstm.ImportFrom(module=cstm.Name("__future__"))
)

def analyze_scopes(self, scopes: set[cstmeta.Scope]):

Check warning on line 133 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L133

Added line #L133 was not covered by tests
for scope in scopes:
for assignment in scope.assignments:
if (
isinstance(assignment, cstmeta.Assignment)
and isinstance(node := assignment.node, (cst.Import, cst.ImportFrom))
and len(assignment.references) == 0
and not self.is_unused_allowed(node, assignment.name)
):
self.__unused_imports[node].add(assignment.name)

Check warning on line 142 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L142

Added line #L142 was not covered by tests

def __leave_import_alike(

Check warning on line 144 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L144

Added line #L144 was not covered by tests
self,
original_node: cst.Import | cst.ImportFrom,
updated_node: cst.Import | cst.ImportFrom,
) -> cst.Import | cst.ImportFrom | cst.RemovalSentinel:
if original_node not in self.__unused_imports or isinstance(
updated_node.names, cst.ImportStar
):
return updated_node

Check warning on line 152 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L152

Added line #L152 was not covered by tests

names_to_keep: list[cst.ImportAlias] = []

Check warning on line 154 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L154

Added line #L154 was not covered by tests

for name in updated_node.names:
if name.asname is not None:
if not isinstance(name.asname, cst.Name):
continue
name_value = name.asname.name.value

Check warning on line 160 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L159-L160

Added lines #L159 - L160 were not covered by tests
else:
name_value = name.name.value

Check warning on line 162 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L162

Added line #L162 was not covered by tests
if name_value not in self.__unused_imports[original_node]:
names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT))

Check warning on line 164 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L164

Added line #L164 was not covered by tests

if len(names_to_keep) == 0:
return cst.RemoveFromParent()

Check warning on line 167 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L167

Added line #L167 was not covered by tests

return updated_node.with_changes(names=names_to_keep)

Check warning on line 169 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L169

Added line #L169 was not covered by tests

def leave_Import(self, original_node: cst.Import, updated_node: cst.Import):
return self.__leave_import_alike(original_node, updated_node)

Check warning on line 172 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L171-L172

Added lines #L171 - L172 were not covered by tests

def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom):
return self.__leave_import_alike(original_node, updated_node)

Check warning on line 175 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L174-L175

Added lines #L174 - L175 were not covered by tests


class CstCodeGenerator:
def __init__(self):
self.__imports: Final[list[Import | ImportFrom]] = []

Check warning on line 180 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L178-L180

Added lines #L178 - L180 were not covered by tests

def add_import(self, module: str, alias: str | None = None):

Check warning on line 182 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L182

Added line #L182 was not covered by tests
if not any(
isinstance(imp, Import) and imp.module.name == module and imp.module.alias == alias
for imp in self.__imports
):
self.__imports.append(Import(ImportAlias(module, alias)))

Check warning on line 187 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L187

Added line #L187 was not covered by tests

def add_import_from(self, module: str, name: str, alias: str | None = None):

Check warning on line 189 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L189

Added line #L189 was not covered by tests
for imp in self.__imports:
if isinstance(imp, ImportFrom) and imp.module == module:
for existing in imp.names:
if existing.name == name and existing.alias == alias:
return
imp.names.append(ImportAlias(name, alias))
return
self.__imports.append(ImportFrom(module, [ImportAlias(name, alias)]))

Check warning on line 197 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L194-L197

Added lines #L194 - L197 were not covered by tests

def make_import_statements(self) -> Sequence[cst.SimpleStatementLine]:

Check warning on line 199 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L199

Added line #L199 was not covered by tests
return [cst.SimpleStatementLine(body=[imp.to_cst()]) for imp in self.__imports]

def apply_transformers(

Check warning on line 202 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L202

Added line #L202 was not covered by tests
self, module: cst.Module, transformers: Sequence[cst.CSTTransformer]
) -> cst.Module:
for transformer in transformers:
wrapper = cstmeta.MetadataWrapper(module)

Check warning on line 206 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L206

Added line #L206 was not covered by tests
if isinstance(transformer, ScopeAnalyzer):
scopes = {
scope
for scope in wrapper.resolve(cstmeta.ScopeProvider).values()
if scope is not None
}
transformer.analyze_scopes(scopes)
module = wrapper.visit(transformer)
return module

Check warning on line 215 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L213-L215

Added lines #L213 - L215 were not covered by tests
Loading

0 comments on commit 99a2f51

Please sign in to comment.