Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a ONNX to ONNX Script code generator based on libcst #873

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions onnxscript/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from __future__ import annotations

Check warning on line 6 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L6

Added line #L6 was not covered by tests

import argparse
from pathlib import Path
from typing import BinaryIO, Protocol

Check warning on line 10 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L8-L10

Added lines #L8 - L10 were not covered by tests

from onnxscript.codeanalysis import onnx_to_onnxscript

Check warning on line 12 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L12

Added line #L12 was not covered by tests


class ConvertCommandArgs(Protocol):
onnx_model_reader: BinaryIO
onnxscript_writer: BinaryIO

Check warning on line 17 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L15-L17

Added lines #L15 - L17 were not covered by tests


def convert_command(args: ConvertCommandArgs):
args.onnxscript_writer.write(

Check warning on line 21 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L20-L21

Added lines #L20 - L21 were not covered by tests
onnx_to_onnxscript.Driver(args.onnx_model_reader).to_python_code(
None
if args.onnxscript_writer.name == "<stdout>"
else Path(args.onnxscript_writer.name)
)
)


def main():
parser = argparse.ArgumentParser(prog="onnxscript")
subparsers = parser.add_subparsers(required=True)

Check warning on line 32 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L30-L32

Added lines #L30 - L32 were not covered by tests

parser_convert = subparsers.add_parser(

Check warning on line 34 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L34

Added line #L34 was not covered by tests
"convert",
help="Convert an ONNX model to ONNX Script Python code",
description="Convert an ONNX model to ONNX Script Python code",
)
parser_convert.set_defaults(func=convert_command)
parser_convert.add_argument(

Check warning on line 40 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L39-L40

Added lines #L39 - L40 were not covered by tests
"onnx_model_reader",
metavar="ONNX_MODEL_FILE",
type=argparse.FileType("rb"),
)
parser_convert.add_argument(

Check warning on line 45 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L45

Added line #L45 was not covered by tests
"--output",
dest="onnxscript_writer",
metavar="OUTPUT_FILE",
type=argparse.FileType("wb"),
help="file path for writing generated ONNX Script code",
default="-",
required=False,
)

args = parser.parse_args()
args.func(args)

Check warning on line 56 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L55-L56

Added lines #L55 - L56 were not covered by tests


if __name__ == "__main__":
main()

Check warning on line 60 in onnxscript/__main__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/__main__.py#L60

Added line #L60 was not covered by tests
215 changes: 215 additions & 0 deletions onnxscript/codeanalysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-ancestors
# --------------------------------------------------------------------------

from __future__ import annotations

Check warning on line 9 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L9

Added line #L9 was not covered by tests

import os
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Final, Protocol, Sequence, runtime_checkable

Check warning on line 15 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L11-L15

Added lines #L11 - L15 were not covered by tests

import libcst as cst
import libcst.matchers as cstm
import libcst.metadata as cstmeta

Check warning on line 19 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L17-L19

Added lines #L17 - L19 were not covered by tests

__all__ = [

Check warning on line 21 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L21

Added line #L21 was not covered by tests
"format_code",
"make_name",
"make_import_alias",
"make_const_expr",
"RemoveUnusedImportsTransformer",
"CstCodeGenerator",
]


def format_code(path: Path | None, code: bytes) -> bytes:
try:
import ufmt

Check warning on line 33 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L31-L33

Added lines #L31 - L33 were not covered by tests

if path is None:
path = Path(os.curdir)

Check warning on line 36 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L36

Added line #L36 was not covered by tests

return ufmt.ufmt_bytes(

Check warning on line 38 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L38

Added line #L38 was not covered by tests
path,
code,
black_config=ufmt.util.make_black_config(path),
usort_config=ufmt.UsortConfig.find(path),
)
except ImportError:
return code

Check warning on line 45 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L44-L45

Added lines #L44 - L45 were not covered by tests


def make_name(name: str) -> cst.Attribute | cst.Name:
tokens = name.split(".")
expr: cst.Name | cst.Attribute = cst.Name(tokens[0])

Check warning on line 50 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L48-L50

Added lines #L48 - L50 were not covered by tests
for attr in tokens[1:]:
expr = cst.Attribute(expr, cst.Name(attr))
return expr

Check warning on line 53 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L52-L53

Added lines #L52 - L53 were not covered by tests


def make_import_alias(name: str, asname: str | None = None) -> cst.ImportAlias:
return cst.ImportAlias(

Check warning on line 57 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L56-L57

Added lines #L56 - L57 were not covered by tests
name=make_name(name),
asname=cst.AsName(cst.Name(asname)) if asname else None,
)


def make_const_expr(const: str | int | float) -> cst.BaseExpression:
negate = False

Check warning on line 64 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L63-L64

Added lines #L63 - L64 were not covered by tests
val: cst.Float | cst.Integer

if isinstance(const, str):
return cst.SimpleString('"' + const.replace('"', '\\"') + '"')

Check warning on line 68 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L68

Added line #L68 was not covered by tests
elif isinstance(const, int):
val = cst.Integer(str(abs(const)))
negate = const < 0

Check warning on line 71 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L70-L71

Added lines #L70 - L71 were not covered by tests
elif isinstance(const, float):
val = cst.Float(str(abs(const)))
negate = const < 0

Check warning on line 74 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L73-L74

Added lines #L73 - L74 were not covered by tests
else:
raise NotImplementedError(repr(const))

Check warning on line 76 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L76

Added line #L76 was not covered by tests

if negate:
return cst.UnaryOperation(

Check warning on line 79 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L79

Added line #L79 was not covered by tests
operator=cst.Minus(),
expression=val,
)

return val

Check warning on line 84 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L84

Added line #L84 was not covered by tests


@dataclass
class ImportAlias:
name: str
alias: str | None = None

Check warning on line 90 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L88-L90

Added lines #L88 - L90 were not covered by tests

def to_cst(self) -> cst.ImportAlias:
return cst.ImportAlias(

Check warning on line 93 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L92-L93

Added lines #L92 - L93 were not covered by tests
make_name(self.name), cst.AsName(cst.Name(self.alias)) if self.alias else None
)


@dataclass
class Import:
module: ImportAlias

Check warning on line 100 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L99-L100

Added lines #L99 - L100 were not covered by tests

def to_cst(self) -> cst.Import:
return cst.Import(names=[self.module.to_cst()])

Check warning on line 103 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L102-L103

Added lines #L102 - L103 were not covered by tests


@dataclass
class ImportFrom:
module: str
names: list[ImportAlias]

Check warning on line 109 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L107-L109

Added lines #L107 - L109 were not covered by tests

def to_cst(self) -> cst.ImportFrom:

Check warning on line 111 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L111

Added line #L111 was not covered by tests
return cst.ImportFrom(
module=make_name(self.module),
names=[name.to_cst() for name in self.names],
)


@runtime_checkable
class ScopeAnalyzer(Protocol):
def analyze_scopes(self, scopes: set[cstmeta.Scope]):
pass

Check warning on line 121 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L119-L121

Added lines #L119 - L121 were not covered by tests


class RemoveUnusedImportsTransformer(cst.CSTTransformer, ScopeAnalyzer):
def __init__(self):
self.__unused_imports: dict[cst.Import | cst.ImportFrom, set[str]] = defaultdict(set)

Check warning on line 126 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L124-L126

Added lines #L124 - L126 were not covered by tests

def is_unused_allowed(self, node: cst.Import | cst.ImportFrom, name: str):
return name == "annotations" and cstm.matches(

Check warning on line 129 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L128-L129

Added lines #L128 - L129 were not covered by tests
node, cstm.ImportFrom(module=cstm.Name("__future__"))
)

def analyze_scopes(self, scopes: set[cstmeta.Scope]):

Check warning on line 133 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L133

Added line #L133 was not covered by tests
for scope in scopes:
for assignment in scope.assignments:
if (
isinstance(assignment, cstmeta.Assignment)
and isinstance(node := assignment.node, (cst.Import, cst.ImportFrom))
and len(assignment.references) == 0
and not self.is_unused_allowed(node, assignment.name)
):
self.__unused_imports[node].add(assignment.name)

Check warning on line 142 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L142

Added line #L142 was not covered by tests

def __leave_import_alike(

Check warning on line 144 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L144

Added line #L144 was not covered by tests
self,
original_node: cst.Import | cst.ImportFrom,
updated_node: cst.Import | cst.ImportFrom,
) -> cst.Import | cst.ImportFrom | cst.RemovalSentinel:
if original_node not in self.__unused_imports or isinstance(
updated_node.names, cst.ImportStar
):
return updated_node

Check warning on line 152 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L152

Added line #L152 was not covered by tests

names_to_keep: list[cst.ImportAlias] = []

Check warning on line 154 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L154

Added line #L154 was not covered by tests

for name in updated_node.names:
if name.asname is not None:
if not isinstance(name.asname, cst.Name):
continue
name_value = name.asname.name.value

Check warning on line 160 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L159-L160

Added lines #L159 - L160 were not covered by tests
else:
name_value = name.name.value

Check warning on line 162 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L162

Added line #L162 was not covered by tests
if name_value not in self.__unused_imports[original_node]:
names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT))

Check warning on line 164 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L164

Added line #L164 was not covered by tests

if len(names_to_keep) == 0:
return cst.RemoveFromParent()

Check warning on line 167 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L167

Added line #L167 was not covered by tests

return updated_node.with_changes(names=names_to_keep)

Check warning on line 169 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L169

Added line #L169 was not covered by tests

def leave_Import(self, original_node: cst.Import, updated_node: cst.Import):
return self.__leave_import_alike(original_node, updated_node)

Check warning on line 172 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L171-L172

Added lines #L171 - L172 were not covered by tests

def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom):
return self.__leave_import_alike(original_node, updated_node)

Check warning on line 175 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L174-L175

Added lines #L174 - L175 were not covered by tests


class CstCodeGenerator:
def __init__(self):
self.__imports: Final[list[Import | ImportFrom]] = []

Check warning on line 180 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L178-L180

Added lines #L178 - L180 were not covered by tests

def add_import(self, module: str, alias: str | None = None):

Check warning on line 182 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L182

Added line #L182 was not covered by tests
if not any(
isinstance(imp, Import) and imp.module.name == module and imp.module.alias == alias
for imp in self.__imports
):
self.__imports.append(Import(ImportAlias(module, alias)))

Check warning on line 187 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L187

Added line #L187 was not covered by tests

def add_import_from(self, module: str, name: str, alias: str | None = None):

Check warning on line 189 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L189

Added line #L189 was not covered by tests
for imp in self.__imports:
if isinstance(imp, ImportFrom) and imp.module == module:
for existing in imp.names:
if existing.name == name and existing.alias == alias:
return
imp.names.append(ImportAlias(name, alias))
return
self.__imports.append(ImportFrom(module, [ImportAlias(name, alias)]))

Check warning on line 197 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L194-L197

Added lines #L194 - L197 were not covered by tests

def make_import_statements(self) -> Sequence[cst.SimpleStatementLine]:

Check warning on line 199 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L199

Added line #L199 was not covered by tests
return [cst.SimpleStatementLine(body=[imp.to_cst()]) for imp in self.__imports]

def apply_transformers(

Check warning on line 202 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L202

Added line #L202 was not covered by tests
self, module: cst.Module, transformers: Sequence[cst.CSTTransformer]
) -> cst.Module:
for transformer in transformers:
wrapper = cstmeta.MetadataWrapper(module)

Check warning on line 206 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L206

Added line #L206 was not covered by tests
if isinstance(transformer, ScopeAnalyzer):
scopes = {
scope
for scope in wrapper.resolve(cstmeta.ScopeProvider).values()
if scope is not None
}
transformer.analyze_scopes(scopes)
module = wrapper.visit(transformer)
return module

Check warning on line 215 in onnxscript/codeanalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/codeanalysis/__init__.py#L213-L215

Added lines #L213 - L215 were not covered by tests
Loading
Loading