Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit test to check backward function for conv, checks there is no graph breaks #1709

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ coverage.xml
cover/
test-output.xml
*.sarif
_dump_*

# Sphinx documentation
docs/_build/
Expand Down
110 changes: 110 additions & 0 deletions onnxscript/function_libs/torch_lib/backward_test.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the OpInfo data structure, I have seen a field that says supports_grad or something which may make it easier for us to generate backward tests. @xiaowuhu do you have some ideas?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems be a different scenario than the OpInfo way. Here, we need to go through the aot-compile-training-backward process which is an e2e scenario, although it is not a straight forward way. But this requirement will only benefit not more than 20 backward functions, so I think it is OK.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG. Thanks!

Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation.
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
# Licensed under the MIT License.
# pylint: disable=not-callable, unbalanced-tuple-unpacking

import copy
import sys
import unittest

import torch

import onnxscript.tools.training_helper
import onnxscript.tools.transformers_models
import onnxscript.tools.transformers_models.llama
from onnxscript._internal.version_utils import has_transformers, torch_older_than


class TestBackward(unittest.TestCase):
@unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
@unittest.skipIf(not has_transformers(), reason="transformers is missing")
@unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
def test_backward_working(self):
class SimpleCNNN(torch.nn.Module):
def __init__(self):
super().__init__()

self.fc1 = torch.nn.Linear(14, 10)

def forward(self, x):
return torch.nn.functional.relu(self.fc1(x))

input_tensors = (torch.randn(1, 1, 14, 14),)
model = SimpleCNNN()
local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

compiled_model = torch.compile(
copy.deepcopy(model),
backend=local_aot_ort,
dynamic=False,
fullgraph=True,
)

expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(
Fixed Show fixed Hide fixed
model, *input_tensors
)
results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
compiled_model,
*input_tensors,
dump_onnx_models=True,
dump_prefix="_dump_testbw_working",
dump_clean_first=True,
)
torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)

# Checking there is only two generated graphs otherwise, it means there are graph breaks.
self.assertEqual(len(onnx_models), 2)
torch.testing.assert_allclose(
expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
)

@unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
@unittest.skipIf(not has_transformers(), reason="transformers is missing")
@unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
@unittest.skipIf(True, reason="aten.conv_backward not implemented yet.")
def test_backward_conv(self):
class SimpleCNNN(torch.nn.Module):
def __init__(self):
super().__init__()

Check warning on line 67 in onnxscript/function_libs/torch_lib/backward_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/backward_test.py#L65-L67

Added lines #L65 - L67 were not covered by tests

self.conv1 = torch.nn.Conv2d(

Check warning on line 69 in onnxscript/function_libs/torch_lib/backward_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/backward_test.py#L69

Added line #L69 was not covered by tests
in_channels=1, out_channels=2, kernel_size=3, padding=1
)
self.fc1 = torch.nn.Linear(14, 10)

Check warning on line 72 in onnxscript/function_libs/torch_lib/backward_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/backward_test.py#L72

Added line #L72 was not covered by tests

def forward(self, x):
y = torch.nn.functional.relu(self.conv1(x))
z = self.fc1(y)
return z

Check warning on line 77 in onnxscript/function_libs/torch_lib/backward_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/backward_test.py#L74-L77

Added lines #L74 - L77 were not covered by tests

input_tensors = (torch.randn(1, 1, 14, 14),)
model = SimpleCNNN()
local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

Check warning on line 81 in onnxscript/function_libs/torch_lib/backward_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/backward_test.py#L79-L81

Added lines #L79 - L81 were not covered by tests

compiled_model = torch.compile(

Check warning on line 83 in onnxscript/function_libs/torch_lib/backward_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/backward_test.py#L83

Added line #L83 was not covered by tests
copy.deepcopy(model),
backend=local_aot_ort,
dynamic=False,
fullgraph=True,
)

expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(

Check warning on line 90 in onnxscript/function_libs/torch_lib/backward_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/backward_test.py#L90

Added line #L90 was not covered by tests
Fixed Show fixed Hide fixed
model, *input_tensors
)
results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(

Check warning on line 93 in onnxscript/function_libs/torch_lib/backward_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/backward_test.py#L93

Added line #L93 was not covered by tests
compiled_model,
*input_tensors,
dump_onnx_models=True,
dump_prefix="_dump_testbw_conv",
dump_clean_first=True,
)
torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)

Check warning on line 100 in onnxscript/function_libs/torch_lib/backward_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/backward_test.py#L100

Added line #L100 was not covered by tests

# Checking there is only two generated graphs otherwise, it means there are graph breaks.
self.assertEqual(len(onnx_models), 2)
torch.testing.assert_allclose(

Check warning on line 104 in onnxscript/function_libs/torch_lib/backward_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/backward_test.py#L103-L104

Added lines #L103 - L104 were not covered by tests
expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
)


if __name__ == "__main__":
unittest.main(verbosity=2)

Check warning on line 110 in onnxscript/function_libs/torch_lib/backward_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/backward_test.py#L110

Added line #L110 was not covered by tests
67 changes: 63 additions & 4 deletions onnxscript/tools/training_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,59 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations

import glob
import os
from typing import Any

import torch
from torch.onnx import ExportOptions
from torch.onnx import _OrtBackend as OrtBackend
from torch.onnx import _OrtBackendOptions as OrtBackendOptions


def make_aot_ort(dynamic: bool = False):
def make_aot_ort(dynamic: bool = False) -> Any:
"""Implements an autograd backend for torch.compile based on onnxrt backend."""
export_options = ExportOptions(dynamic_shapes=dynamic)
options = OrtBackendOptions(export_options=export_options)
ort_backend = OrtBackend(options=options)
return ort_backend


def train_loop(model, *args, loss_fn=None, optimizer=None):
"""Implements a training loop to be used in tests."""
def train_loop(
model: Any,
*args,
loss_fn: Any | None = None,
optimizer: Any | None = None,
dump_onnx_models: bool = False,
dump_prefix: str = "dump_train_loop",
dump_clean_first: bool = True,
) -> tuple[Any, tuple[Any, ...]] | tuple[Any, tuple[Any, ...], list[str]]:
Comment on lines +25 to +33

Check notice

Code scanning / CodeQL

Returning tuples with varying lengths Note

train_loop returns
tuple of size 2
and
tuple of size 3
.
"""Implements a training loop to be used in tests.
The function returns the forward output and gradients in a tuple.

if dump_onnx_models is True, the function returns the forward output,
the gradients in a tuple and the generated onnx_files.
If there is no graph break, there should be
two graphs, one for forward, one for backward.

Args:
model: pytorch model
args: inputs
loss_fn: loss function, default is MSELoss
optimizer: optimizer, default is SGD
dump_onnx_models: dumps the model onnxrt backend is producing
dump_prefix: names will be `<dump_prefix>0.onnx`, `<dump_prefix>1.onnx`, ...
dump_clean_first: clean all files starting with the given prefix

Returns:
- the forward outputs
- the backwards gradients
- the dumped onnx models, 2 at least unless the forward, backward
were called before this function is executed or if the model
is not a compiled model
"""

if loss_fn is None:
loss_fn = torch.nn.MSELoss()
Expand All @@ -28,6 +65,16 @@
# Unnecessary in this situation but added for best practices
model.train()

if dump_onnx_models:
if dump_clean_first:
names = glob.glob(f"{dump_prefix}*")
for name in names:
os.remove(name)

Check warning on line 72 in onnxscript/tools/training_helper.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tools/training_helper.py#L72

Added line #L72 was not covered by tests

old_value = os.environ.get("ONNXRT_DUMP_PATH", None)
os.environ["ONNXRT_DUMP_PATH"] = f"{dump_prefix}_forward"
existing_files = glob.glob(f"{dump_prefix}*.onnx")

# Compute prediction and loss
pred = model(*args)
if isinstance(pred, tuple):
Expand All @@ -39,6 +86,8 @@
loss = loss_fn(v, torch.ones_like(v))

# Backpropagation
if dump_onnx_models:
os.environ["ONNXRT_DUMP_PATH"] = f"{dump_prefix}_backward"
loss.backward()
optimizer.step()
# skip that part to retrieve the gradients
Expand All @@ -47,4 +96,14 @@
# returns the gradients
res = tuple(p.grad for p in model.parameters() if p.grad is not None)
assert len(res) > 0, f"No gradient, loss is {loss}"
return res

if dump_onnx_models:
if old_value is None:
del os.environ["ONNXRT_DUMP_PATH"]
else:
os.environ["ONNXRT_DUMP_PATH"] = old_value

Check warning on line 104 in onnxscript/tools/training_helper.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tools/training_helper.py#L104

Added line #L104 was not covered by tests
new_files = glob.glob(f"{dump_prefix}*.onnx")
added_files = set(new_files) - set(existing_files)
return pred, res, [f for f in new_files if f in added_files]

return pred, res
15 changes: 11 additions & 4 deletions onnxscript/tools/transformers_models/llama_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@
onnxscript.tools.transformers_models.llama.get_llama_model()
)
input_tensors = input_tensors_many[0]
expected = model(*input_tensors)

local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

compiled_model = torch.compile(
Expand All @@ -131,8 +129,17 @@
fullgraph=True,
)

results = compiled_model(*input_tensors)
torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)
expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop( # pylint: disable=unbalanced-tuple-unpacking
model, *input_tensors
)
results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable onnx\_models is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable

Check warning

Code scanning / lintrunner

PYLINT/W0612 Warning

Unused variable 'onnx_models' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable
compiled_model,
*input_tensors,
dump_onnx_models=True,
dump_prefix="_dump_dort_llama",
dump_clean_first=True,
)
torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)

expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
Expand Down
25 changes: 17 additions & 8 deletions onnxscript/tools/transformers_models/mistral_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=not-callable
# pylint: disable=not-callable, unbalanced-tuple-unpacking

import copy
import sys
Expand Down Expand Up @@ -122,8 +122,6 @@
onnxscript.tools.transformers_models.mistral.get_mistral_model()
)
input_tensors = input_tensors_many[0]
expected = model(*input_tensors)

local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

compiled_model = torch.compile(
Expand All @@ -133,12 +131,23 @@
fullgraph=True,
)

results = compiled_model(*input_tensors)
torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)
expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop( # pylint: disable=unbalanced-tuple-unpacking

Check warning on line 134 in onnxscript/tools/transformers_models/mistral_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tools/transformers_models/mistral_test.py#L134

Added line #L134 was not covered by tests
model, *input_tensors
)
results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(

Check warning on line 137 in onnxscript/tools/transformers_models/mistral_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tools/transformers_models/mistral_test.py#L137

Added line #L137 was not covered by tests
compiled_model,
*input_tensors,
dump_onnx_models=True,
dump_prefix="_dump_dort_mistral",
dump_clean_first=True,
)
torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)

Check warning on line 144 in onnxscript/tools/transformers_models/mistral_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tools/transformers_models/mistral_test.py#L144

Added line #L144 was not covered by tests

expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)
# Checking there is only two generated graphs otherwise, it means there are graph breaks.
self.assertEqual(len(onnx_models), 2)
torch.testing.assert_allclose(

Check warning on line 148 in onnxscript/tools/transformers_models/mistral_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tools/transformers_models/mistral_test.py#L147-L148

Added lines #L147 - L148 were not covered by tests
expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
)


if __name__ == "__main__":
Expand Down
25 changes: 17 additions & 8 deletions onnxscript/tools/transformers_models/phi3_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=not-callable
# pylint: disable=not-callable, unbalanced-tuple-unpacking

import copy
import sys
Expand Down Expand Up @@ -123,8 +123,6 @@
onnxscript.tools.transformers_models.phi3.get_phi3_model()
)
input_tensors = input_tensors_many[0]
expected = model(*input_tensors)

local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

compiled_model = torch.compile(
Expand All @@ -134,12 +132,23 @@
fullgraph=True,
)

results = compiled_model(*input_tensors)
torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)
expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop( # pylint: disable=unbalanced-tuple-unpacking

Check warning on line 135 in onnxscript/tools/transformers_models/phi3_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tools/transformers_models/phi3_test.py#L135

Added line #L135 was not covered by tests
model, *input_tensors
)
results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(

Check warning on line 138 in onnxscript/tools/transformers_models/phi3_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tools/transformers_models/phi3_test.py#L138

Added line #L138 was not covered by tests
compiled_model,
*input_tensors,
dump_onnx_models=True,
dump_prefix="_dump_dort_phi3",
dump_clean_first=True,
)
torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)

Check warning on line 145 in onnxscript/tools/transformers_models/phi3_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tools/transformers_models/phi3_test.py#L145

Added line #L145 was not covered by tests

expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)
# Checking there is only two generated graphs otherwise, it means there are graph breaks.
self.assertEqual(len(onnx_models), 2)
torch.testing.assert_allclose(

Check warning on line 149 in onnxscript/tools/transformers_models/phi3_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tools/transformers_models/phi3_test.py#L148-L149

Added lines #L148 - L149 were not covered by tests
expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
)


if __name__ == "__main__":
Expand Down
24 changes: 17 additions & 7 deletions onnxscript/tools/transformers_models/phi_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=not-callable
# pylint: disable=not-callable, unbalanced-tuple-unpacking

import copy
import sys
Expand Down Expand Up @@ -87,7 +87,6 @@ def test_phi_export_cuda(self):
def test_phi_dort_static(self):
model, input_tensors_many, _ = onnxscript.tools.transformers_models.phi.get_phi_model()
input_tensors = input_tensors_many[0]
expected = model(*input_tensors)

local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

Expand All @@ -98,12 +97,23 @@ def test_phi_dort_static(self):
fullgraph=True,
)

results = compiled_model(*input_tensors)
torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)
expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop( # pylint: disable=unbalanced-tuple-unpacking
model, *input_tensors
)
results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
compiled_model,
*input_tensors,
dump_onnx_models=True,
dump_prefix="_dump_dort_phi",
dump_clean_first=True,
)
torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)

expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)
# Checking there is only two generated graphs otherwise, it means there are graph breaks.
self.assertEqual(len(onnx_models), 2)
torch.testing.assert_allclose(
expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
)


if __name__ == "__main__":
Expand Down
Loading