Add unit test to check the backward function for conv and verify there are no graph breaks #1709

Open · wants to merge 7 commits into base: main
Changes from 6 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -42,6 +42,7 @@ coverage.xml
cover/
test-output.xml
*.sarif
_dump_*

# Sphinx documentation
docs/_build/
111 changes: 111 additions & 0 deletions onnxscript/function_libs/torch_lib/backward_test.py
Review comment (Collaborator): In the OpInfo data structure, I have seen a field called supports_grad or something similar, which may make it easier for us to generate backward tests. @xiaowuhu do you have some ideas?

Review comment (Contributor): This seems to be a different scenario than the OpInfo approach. Here we need to go through the aot-compile-training-backward process, which is an end-to-end scenario, even though it is not the most straightforward way. But this requirement will only benefit at most about 20 backward functions, so I think it is OK.

Review comment (Collaborator): SG. Thanks!

@@ -0,0 +1,111 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=not-callable, unbalanced-tuple-unpacking

import copy
import sys
import unittest

import torch

import onnxscript.tools.training_helper
from onnxscript._internal.version_utils import has_transformers, torch_older_than


class TestBackward(unittest.TestCase):
    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
    @unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
    def test_backward_working(self):
        class SimpleCNNN(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(14, 10)

            def forward(self, x):
                return torch.nn.functional.relu(self.fc1(x))

        input_tensors = (torch.randn(1, 1, 14, 14),)
        model = SimpleCNNN()
        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

        compiled_model = torch.compile(
            copy.deepcopy(model),
            backend=local_aot_ort,
            dynamic=False,
            fullgraph=True,
        )

        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(
            model, *input_tensors
        )
        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
            compiled_model,
            *input_tensors,
            dump_onnx_models=True,
            dump_prefix="_dump_testbw_working",
            dump_clean_first=True,
        )
        torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)

        # Check that exactly two graphs were generated; more would mean graph breaks.
        self.assertEqual(len(onnx_models), 2)
        torch.testing.assert_allclose(
            expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
        )

    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
    @unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
    @unittest.skipIf(True, reason="aten.conv_backward not implemented yet.")
    def test_backward_conv(self):
        class SimpleCNNN(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    in_channels=1, out_channels=2, kernel_size=3, padding=1
                )
                self.fc1 = torch.nn.Linear(14, 10)

            def forward(self, x):
                y = torch.nn.functional.relu(self.conv1(x))
                z = self.fc1(y)
                return z

        input_tensors = (torch.randn(1, 1, 14, 14),)
        model = SimpleCNNN()
        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

        compiled_model = torch.compile(
            copy.deepcopy(model),
            backend=local_aot_ort,
            dynamic=False,
            fullgraph=True,
        )

        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(
            model, *input_tensors
        )
        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
            compiled_model,
            *input_tensors,
            dump_onnx_models=True,
            dump_prefix="_dump_testbw_conv",
            dump_clean_first=True,
        )
        torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)

        # Check that exactly two graphs were generated; more would mean graph breaks.
        self.assertEqual(len(onnx_models), 2)
        torch.testing.assert_allclose(
            expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
        )


if __name__ == "__main__":
    unittest.main(verbosity=2)

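A note on running the new file: it is a plain unittest module (it ends in unittest.main), so either invocation below works; shown for illustration, assuming pytest is installed for the first form. The _dump_* files it produces are covered by the new .gitignore entry above.

    python -m pytest onnxscript/function_libs/torch_lib/backward_test.py -v
    python onnxscript/function_libs/torch_lib/backward_test.py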
67 changes: 63 additions & 4 deletions onnxscript/tools/training_helper.py
@@ -2,22 +2,59 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations

import glob
import os
from typing import Any

import torch
from torch.onnx import ExportOptions
from torch.onnx import _OrtBackend as OrtBackend
from torch.onnx import _OrtBackendOptions as OrtBackendOptions


def make_aot_ort(dynamic: bool = False):
def make_aot_ort(dynamic: bool = False) -> Any:
"""Implements an autograd backend for torch.compile based on onnxrt backend."""
export_options = ExportOptions(dynamic_shapes=dynamic)
options = OrtBackendOptions(export_options=export_options)
ort_backend = OrtBackend(options=options)
return ort_backend


def train_loop(model, *args, loss_fn=None, optimizer=None):
    """Implements a training loop to be used in tests."""
def train_loop(
    model: Any,
    *args,
    loss_fn: Any | None = None,
    optimizer: Any | None = None,
    dump_onnx_models: bool = False,
    dump_prefix: str = "dump_train_loop",
    dump_clean_first: bool = True,
) -> tuple[Any, tuple[Any, ...]] | tuple[Any, tuple[Any, ...], list[str]]:
Code scanning note (CodeQL), on lines +25 to +33: Returning tuples with varying lengths — train_loop returns a tuple of size 2 and a tuple of size 3.
"""Implements a training loop to be used in tests.
The function returns the forward output and gradients in a tuple.

if dump_onnx_models is True, the function returns the forward output,
the gradients in a tuple and the generated onnx_files.
If there is no graph break, there should be
two graphs, one for forward, one for backward.

Args:
model: pytorch model
args: inputs
loss_fn: loss function, default is MSELoss
optimizer: optimizer, default is SGD
dump_onnx_models: dumps the model onnxrt backend is producing
dump_prefix: names will be `<dump_prefix>0.onnx`, `<dump_prefix>1.onnx`, ...
dump_clean_first: clean all files starting with the given prefix

Returns:
- the forward outputs
- the backwards gradients
- the dumped onnw models, 2 at least unless the forward, backward
were called before this function is executed or if the model
is not a compiled model
"""

    if loss_fn is None:
        loss_fn = torch.nn.MSELoss()
@@ -28,6 +65,16 @@
    # Unnecessary in this situation but added for best practices
    model.train()

    if dump_onnx_models:
        if dump_clean_first:
            names = glob.glob(f"{dump_prefix}*")
            for name in names:
                os.remove(name)

        old_value = os.environ.get("ONNXRT_DUMP_PATH", None)
        os.environ["ONNXRT_DUMP_PATH"] = f"{dump_prefix}_forward"
        existing_files = glob.glob(f"{dump_prefix}*.onnx")

    # Compute prediction and loss
    pred = model(*args)
    if isinstance(pred, tuple):
@@ -39,6 +86,8 @@
    loss = loss_fn(v, torch.ones_like(v))

    # Backpropagation
    if dump_onnx_models:
        os.environ["ONNXRT_DUMP_PATH"] = f"{dump_prefix}_backward"
    loss.backward()
    optimizer.step()
    # skip that part to retrieve the gradients
@@ -47,4 +96,14 @@
    # returns the gradients
    res = tuple(p.grad for p in model.parameters() if p.grad is not None)
    assert len(res) > 0, f"No gradient, loss is {loss}"
    return res

    if dump_onnx_models:
        if old_value is None:
            del os.environ["ONNXRT_DUMP_PATH"]
        else:
            os.environ["ONNXRT_DUMP_PATH"] = old_value

        new_files = glob.glob(f"{dump_prefix}*.onnx")
        added_files = set(new_files) - set(existing_files)
        return pred, res, [f for f in new_files if f in added_files]

    return pred, res
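
As a usage note, here is a minimal sketch of the updated helper, mirroring what the tests in this PR do; the model, input, and prefix below are made up for illustration. train_loop returns two values for a regular call and three when dump_onnx_models=True, which is why the callers disable pylint's unbalanced-tuple-unpacking:

# Minimal usage sketch, assuming the helpers above; names are illustrative.
import copy
import torch
import onnxscript.tools.training_helper as th

model = torch.nn.Linear(14, 10)
x = torch.randn(1, 1, 14, 14)
compiled = torch.compile(
    copy.deepcopy(model),
    backend=th.make_aot_ort(dynamic=False),
    dynamic=False,
    fullgraph=True,  # fail instead of silently splitting the graph
)

expected, expected_grads = th.train_loop(model, x)  # two return values
results, grads, onnx_files = th.train_loop(  # three return values
    compiled,
    x,
    dump_onnx_models=True,
    dump_prefix="_dump_example",
)
# One forward graph plus one backward graph means no graph breaks.
assert len(onnx_files) == 2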
23 changes: 16 additions & 7 deletions onnxscript/tools/transformers_models/llama_test.py
@@ -120,8 +120,6 @@
            onnxscript.tools.transformers_models.llama.get_llama_model()
        )
        input_tensors = input_tensors_many[0]
        expected = model(*input_tensors)

        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

        compiled_model = torch.compile(
@@ -131,12 +129,23 @@
            fullgraph=True,
        )

        results = compiled_model(*input_tensors)
        torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)
        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(
            model, *input_tensors
        )
        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
            compiled_model,
            *input_tensors,
            dump_onnx_models=True,
            dump_prefix="_dump_dort_llama",
            dump_clean_first=True,
        )
        torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)

        expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
        gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
        torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)
        # Check that exactly two graphs were generated; more would mean graph breaks.
        self.assertEqual(len(onnx_models), 2)
        torch.testing.assert_allclose(
            expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
        )


if __name__ == "__main__":
25 changes: 17 additions & 8 deletions onnxscript/tools/transformers_models/mistral_test.py
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=not-callable
# pylint: disable=not-callable, unbalanced-tuple-unpacking

import copy
import sys
@@ -122,8 +122,6 @@
            onnxscript.tools.transformers_models.mistral.get_mistral_model()
        )
        input_tensors = input_tensors_many[0]
        expected = model(*input_tensors)

        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

        compiled_model = torch.compile(
@@ -133,12 +131,23 @@
            fullgraph=True,
        )

        results = compiled_model(*input_tensors)
        torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)
        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(
            model, *input_tensors
        )
        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
            compiled_model,
            *input_tensors,
            dump_onnx_models=True,
            dump_prefix="_dump_dort_mistral",
            dump_clean_first=True,
        )
        torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)

        expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
        gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
        torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)
        # Check that exactly two graphs were generated; more would mean graph breaks.
        self.assertEqual(len(onnx_models), 2)
        torch.testing.assert_allclose(
            expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
        )


if __name__ == "__main__":
25 changes: 17 additions & 8 deletions onnxscript/tools/transformers_models/phi3_test.py
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=not-callable
# pylint: disable=not-callable, unbalanced-tuple-unpacking

import copy
import sys
@@ -123,8 +123,6 @@
            onnxscript.tools.transformers_models.phi3.get_phi3_model()
        )
        input_tensors = input_tensors_many[0]
        expected = model(*input_tensors)

        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

        compiled_model = torch.compile(
@@ -134,12 +132,23 @@
            fullgraph=True,
        )

        results = compiled_model(*input_tensors)
        torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)
        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(
            model, *input_tensors
        )
        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
            compiled_model,
            *input_tensors,
            dump_onnx_models=True,
            dump_prefix="_dump_dort_phi3",
            dump_clean_first=True,
        )
        torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)

        expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
        gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
        torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)
        # Check that exactly two graphs were generated; more would mean graph breaks.
        self.assertEqual(len(onnx_models), 2)
        torch.testing.assert_allclose(
            expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
        )


if __name__ == "__main__":
24 changes: 17 additions & 7 deletions onnxscript/tools/transformers_models/phi_test.py
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=not-callable
# pylint: disable=not-callable, unbalanced-tuple-unpacking

import copy
import sys
@@ -87,7 +87,6 @@ def test_phi_export_cuda(self):
    def test_phi_dort_static(self):
        model, input_tensors_many, _ = onnxscript.tools.transformers_models.phi.get_phi_model()
        input_tensors = input_tensors_many[0]
        expected = model(*input_tensors)

        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

@@ -98,12 +97,23 @@
            fullgraph=True,
        )

        results = compiled_model(*input_tensors)
        torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)
        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(
            model, *input_tensors
        )
        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
            compiled_model,
            *input_tensors,
            dump_onnx_models=True,
            dump_prefix="_dump_dort_phi",
            dump_clean_first=True,
        )
        torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)

        expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
        gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
        torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)
        # Check that exactly two graphs were generated; more would mean graph breaks.
        self.assertEqual(len(onnx_models), 2)
        torch.testing.assert_allclose(
            expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
        )


if __name__ == "__main__":