Opsets not imported for functions when the op is used in an if branch #1109

Open
justinchuby opened this issue Oct 24, 2023 · 1 comment · May be fixed by #1115
Labels
bug Something isn't working


@justinchuby
Collaborator

justinchuby commented Oct 24, 2023

The following functions

```python
import onnxscript
from onnxscript import opset18 as op  # `op` is used below, so it must be imported

# Two custom opsets: IsScalar lives in "common", aten_clamp_max in "torch_lib".
common_opset = onnxscript.values.Opset(domain="pkg.onnxscript.torch_lib.common", version=1)
torchlib_opset = onnxscript.values.Opset(domain="pkg.onnxscript.torch_lib", version=1)

@onnxscript.script(common_opset)
def IsScalar(input):
    """Return whether the input has rank 0, or is a scalar."""

    return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0))


@onnxscript.script(torchlib_opset)
def aten_clamp_max(self, max_):
    """clamp_max(Tensor self, Tensor max) -> Tensor"""

    self_size = op.Size(self)
    max_shape = op.Shape(max_)
    if self_size == 0:
        result = op.Expand(self, max_shape)
    else:
        # IsScalar (from the custom "common" opset) is called inside the else branch,
        # i.e. inside an If subgraph rather than at the top level of the function.
        if IsScalar(max_):
            max_ = op.CastLike(max_, self)
            result = op.Clip(self, None, max_)
        else:
            result = op.Min(self, max_)

    return result
```

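Here `IsScalar` is an `OnnxFunction` from a custom opset (`pkg.onnxscript.torch_lib.common`), but that opset is not imported by the calling function: in the generated model, `aten_clamp_max` declares only `"" : 18` in its `opset_import`. I notice `IsScalar` is called inside an `if` branch (an `If` subgraph), so that may be the cause.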
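One plausible reading of that hypothesis: if the opset imports for a function are collected only from its top-level nodes, a call that sits inside an `If` branch is never seen. The sketch below illustrates this with hypothetical, simplified `Node`/`Graph` dataclasses (not the onnx protobuf API and not onnxscript's internal IR), whose structure mirrors `aten_clamp_max`. It is an illustration of the suspected failure mode, not the actual onnxscript code path.

```python
from __future__ import annotations
from dataclasses import dataclass, field

# Hypothetical, simplified stand-ins for ONNX graphs and nodes.
@dataclass
class Graph:
    nodes: list[Node] = field(default_factory=list)

@dataclass
class Node:
    op_type: str
    domain: str = ""  # "" is the default ONNX domain
    subgraphs: list[Graph] = field(default_factory=list)  # e.g. If's then/else branches

def top_level_domains(nodes: list[Node]) -> set[str]:
    """Domains used by the given nodes, ignoring any nested subgraphs."""
    return {n.domain for n in nodes}

def all_domains(nodes: list[Node]) -> set[str]:
    """Domains used by the nodes and, recursively, by every subgraph they contain."""
    found: set[str] = set()
    for n in nodes:
        found.add(n.domain)
        for g in n.subgraphs:
            found |= all_domains(g.nodes)
    return found

COMMON = "pkg.onnxscript.torch_lib.common"

# Same shape as aten_clamp_max: the IsScalar call lives in the outer If's else branch.
inner_if = Node("If", subgraphs=[
    Graph([Node("CastLike"), Node("Clip")]),
    Graph([Node("Min")]),
])
outer_if = Node("If", subgraphs=[
    Graph([Node("Expand")]),
    Graph([Node("IsScalar", domain=COMMON), inner_if]),
])
body = [Node("Size"), Node("Shape"), Node("Constant"), Node("CastLike"), Node("Equal"), outer_if]

print(sorted(top_level_domains(body)))  # ['']
print(sorted(all_domains(body)))        # ['', 'pkg.onnxscript.torch_lib.common']
```

Only the recursive walk finds the custom domain, which matches the `opset_import` that `aten_clamp_max` ends up with in the generated model.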

Generated model:

```
<
   ir_version: 8,
   opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1],
   producer_name: "pytorch",
   producer_version: "2.2.0"
>
main_graph (float16[5] input_0, float16 input_1) => (float16[5] _val_2) {
   _val_2 = pkg.onnxscript.torch_lib.aten_clamp_max (input_0, input_1)
}
<
  domain: "pkg.onnxscript.torch_lib",
  opset_import: ["" : 18]
>
aten_clamp_max (self, max_) => (result_5)
{
   self_size = Size (self)
   max_shape = Shape (max_)
   int64_0 = Constant <value: tensor = int64 int64_0 {0}> ()
   int64_0_cast = CastLike (int64_0, self_size)
   cond = Equal (self_size, int64_0_cast)
   result_5 = If (cond) <then_branch: graph = thenGraph_7 () => ( result) {
      result = Expand (self, max_shape)
   }, else_branch: graph = elseGraph_7 () => ( result_4) {
      cond_0 = pkg.onnxscript.torch_lib.common.IsScalar (max_)
      result_4 = If (cond_0) <then_branch: graph = thenGraph_10 () => ( result_2) {
         max__1 = CastLike (max_, self)
         result_2 = Clip (self, , max__1)
      }, else_branch: graph = elseGraph_10 () => ( result_3) {
         result_3 = Min (self, max_)
      }>
   }>
}
<
  domain: "pkg.onnxscript.torch_lib.common",
  opset_import: ["" : 18]
>
Rank (input) => (return_val)
{
   tmp = Shape (input)
   return_val = Size (tmp)
}
<
  domain: "pkg.onnxscript.torch_lib.common",
  opset_import: ["" : 18]
>
IsScalar (input) => (return_val)
{
   tmp = Shape (input)
   tmp_0 = Size (tmp)
   tmp_1 = Constant <value_int: int = 0> ()
   return_val = Equal (tmp_0, tmp_1)
}
```
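To make the mismatch in the listing concrete, here is a sketch of an audit over every function in the generated model: it reports domains a function uses (including inside `If` branches) but does not import, then patches them. The `Node`/`Graph`/`Function` dataclasses and the `audit`/`patch` helpers are hypothetical simplifications, not the onnx or onnxscript API, and taking the missing version from the model-level `opset_import` is an assumption of this sketch, not necessarily what the linked fix does.

```python
from __future__ import annotations
from dataclasses import dataclass, field

# Hypothetical, simplified stand-ins for ONNX nodes, graphs, and functions.
@dataclass
class Graph:
    nodes: list[Node] = field(default_factory=list)

@dataclass
class Node:
    op_type: str
    domain: str = ""  # "" is the default ONNX domain
    subgraphs: list[Graph] = field(default_factory=list)

@dataclass
class Function:
    name: str
    opset_import: dict[str, int]
    nodes: list[Node]

def used_domains(nodes: list[Node]) -> set[str]:
    """Collect node domains, descending into subgraphs such as If branches."""
    found: set[str] = set()
    for n in nodes:
        found.add(n.domain)
        for g in n.subgraphs:
            found |= used_domains(g.nodes)
    return found

def audit(functions: list[Function]) -> dict[str, set[str]]:
    """Map each function name to the domains it uses but does not import."""
    report: dict[str, set[str]] = {}
    for f in functions:
        missing = used_domains(f.nodes) - set(f.opset_import)
        if missing:
            report[f.name] = missing
    return report

def patch(functions: list[Function], model_opsets: dict[str, int]) -> None:
    """Assumption: import each missing domain at the model-level version."""
    for f in functions:
        for domain in used_domains(f.nodes) - set(f.opset_import):
            f.opset_import[domain] = model_opsets[domain]

COMMON = "pkg.onnxscript.torch_lib.common"
# Versions and function bodies transcribed from the generated model.
model_opsets = {"pkg.onnxscript.torch_lib": 1, "": 18, COMMON: 1}
functions = [
    Function("aten_clamp_max", {"": 18}, [
        Node("Size"), Node("Shape"), Node("Constant"), Node("CastLike"), Node("Equal"),
        Node("If", subgraphs=[
            Graph([Node("Expand")]),
            Graph([
                Node("IsScalar", domain=COMMON),
                Node("If", subgraphs=[Graph([Node("CastLike"), Node("Clip")]),
                                      Graph([Node("Min")])]),
            ]),
        ]),
    ]),
    Function("Rank", {"": 18}, [Node("Shape"), Node("Size")]),
    Function("IsScalar", {"": 18}, [Node("Shape"), Node("Size"), Node("Constant"), Node("Equal")]),
]

print(audit(functions))  # {'aten_clamp_max': {'pkg.onnxscript.torch_lib.common'}}
patch(functions, model_opsets)
print(audit(functions))  # {}
```

Only `aten_clamp_max` is flagged; `Rank` and `IsScalar` use just the default domain, which they already import.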

Original issue onnx/onnx#5701

cc @gramalingam

@gramalingam
Collaborator

Looking. I wonder if this is related to this issue/comment: #59 (comment)
