diff --git a/.gitignore b/.gitignore
index 9e6f1a45c..203580c52 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,6 +42,7 @@ coverage.xml
 cover/
 test-output.xml
 *.sarif
+_dump_*
 
 # Sphinx documentation
 docs/_build/
diff --git a/onnxscript/function_libs/torch_lib/backward_test.py b/onnxscript/function_libs/torch_lib/backward_test.py
new file mode 100644
index 000000000..25e0c2e27
--- /dev/null
+++ b/onnxscript/function_libs/torch_lib/backward_test.py
@@ -0,0 +1,110 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# pylint: disable=not-callable, unbalanced-tuple-unpacking
+
+import copy
+import sys
+import unittest
+
+import torch
+
+import onnxscript.tools.training_helper
+import onnxscript.tools.transformers_models
+import onnxscript.tools.transformers_models.llama
+from onnxscript._internal.version_utils import has_transformers, torch_older_than
+
+
+class TestBackward(unittest.TestCase):
+    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
+    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
+    @unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
+    def test_backward_working(self):
+        class SimpleCNNN(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.fc1 = torch.nn.Linear(14, 10)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.fc1(x))
+
+        input_tensors = (torch.randn(1, 1, 14, 14),)
+        model = SimpleCNNN()
+        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)
+
+        compiled_model = torch.compile(
+            copy.deepcopy(model),
+            backend=local_aot_ort,
+            dynamic=False,
+            fullgraph=True,
+        )
+
+        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(
+            model, *input_tensors
+        )
+        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
+            compiled_model,
+            *input_tensors,
+            dump_onnx_models=True,
+            dump_prefix="_dump_testbw_working",
+            dump_clean_first=True,
+        )
+        torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)
+
+        # Check that only two graphs were generated; otherwise, there are graph breaks.
+        self.assertEqual(len(onnx_models), 2)
+        torch.testing.assert_allclose(
+            expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
+        )
+
+    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
+    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
+    @unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
+    @unittest.skipIf(True, reason="aten.conv_backward not implemented yet.")
+    def test_backward_conv(self):
+        class SimpleCNNN(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.conv1 = torch.nn.Conv2d(
+                    in_channels=1, out_channels=2, kernel_size=3, padding=1
+                )
+                self.fc1 = torch.nn.Linear(14, 10)
+
+            def forward(self, x):
+                y = torch.nn.functional.relu(self.conv1(x))
+                z = self.fc1(y)
+                return z
+
+        input_tensors = (torch.randn(1, 1, 14, 14),)
+        model = SimpleCNNN()
+        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)
+
+        compiled_model = torch.compile(
+            copy.deepcopy(model),
+            backend=local_aot_ort,
+            dynamic=False,
+            fullgraph=True,
+        )
+
+        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(
+            model, *input_tensors
+        )
+        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
+            compiled_model,
+            *input_tensors,
+            dump_onnx_models=True,
+            dump_prefix="_dump_testbw_conv",
+            dump_clean_first=True,
+        )
+        torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)
+
+        # Check that only two graphs were generated; otherwise, there are graph breaks.
+        self.assertEqual(len(onnx_models), 2)
+        torch.testing.assert_allclose(
+            expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
+        )
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/onnxscript/tools/training_helper.py b/onnxscript/tools/training_helper.py
index 785b2e6fb..de26dd782 100644
--- a/onnxscript/tools/training_helper.py
+++ b/onnxscript/tools/training_helper.py
@@ -2,13 +2,19 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
+from __future__ import annotations
+
+import glob
+import os
+from typing import Any
+
 import torch
 from torch.onnx import ExportOptions
 from torch.onnx import _OrtBackend as OrtBackend
 from torch.onnx import _OrtBackendOptions as OrtBackendOptions
 
 
-def make_aot_ort(dynamic: bool = False):
+def make_aot_ort(dynamic: bool = False) -> Any:
     """Implements an autograd backend for torch.compile based on onnxrt backend."""
     export_options = ExportOptions(dynamic_shapes=dynamic)
     options = OrtBackendOptions(export_options=export_options)
@@ -16,8 +22,39 @@ def make_aot_ort(dynamic: bool = False):
     return ort_backend
 
 
-def train_loop(model, *args, loss_fn=None, optimizer=None):
-    """Implements a training loop to be used in tests."""
+def train_loop(
+    model: Any,
+    *args,
+    loss_fn: Any | None = None,
+    optimizer: Any | None = None,
+    dump_onnx_models: bool = False,
+    dump_prefix: str = "dump_train_loop",
+    dump_clean_first: bool = True,
+) -> tuple[Any, tuple[Any, ...]] | tuple[Any, tuple[Any, ...], list[str]]:
+    """Implements a training loop to be used in tests.
+    The function returns the forward output and the gradients in a tuple.
+
+    If dump_onnx_models is True, the function returns the forward output,
+    the gradients in a tuple, and the list of generated ONNX files.
+    If there is no graph break, there should be
+    two graphs, one for forward and one for backward.
+
+    Args:
+        model: pytorch model
+        args: input tensors
+        loss_fn: loss function, default is MSELoss
+        optimizer: optimizer, default is SGD
+        dump_onnx_models: dumps the models produced by the onnxrt backend
+        dump_prefix: prefix of the dumped file names (`0.onnx`, `1.onnx`, ... are appended)
+        dump_clean_first: removes all files starting with the given prefix first
+
+    Returns:
+        - the forward outputs
+        - the backward gradients
+        - the dumped onnx models, at least two unless forward and backward
+          were already called before this function is executed or the model
+          is not a compiled model
+    """
 
     if loss_fn is None:
         loss_fn = torch.nn.MSELoss()
@@ -28,6 +65,16 @@ def train_loop(model, *args, loss_fn=None, optimizer=None):
     # Unnecessary in this situation but added for best practices
     model.train()
 
+    if dump_onnx_models:
+        if dump_clean_first:
+            names = glob.glob(f"{dump_prefix}*")
+            for name in names:
+                os.remove(name)
+
+        old_value = os.environ.get("ONNXRT_DUMP_PATH", None)
+        os.environ["ONNXRT_DUMP_PATH"] = f"{dump_prefix}_forward"
+        existing_files = glob.glob(f"{dump_prefix}*.onnx")
+
     # Compute prediction and loss
     pred = model(*args)
     if isinstance(pred, tuple):
@@ -39,6 +86,8 @@ def train_loop(model, *args, loss_fn=None, optimizer=None):
     loss = loss_fn(v, torch.ones_like(v))
 
     # Backpropagation
+    if dump_onnx_models:
+        os.environ["ONNXRT_DUMP_PATH"] = f"{dump_prefix}_backward"
     loss.backward()
     optimizer.step()
     # skip that part to retrieve the gradients
@@ -47,4 +96,14 @@ def train_loop(model, *args, loss_fn=None, optimizer=None):
     # returns the gradients
     res = tuple(p.grad for p in model.parameters() if p.grad is not None)
     assert len(res) > 0, f"No gradient, loss is {loss}"
-    return res
+
+    if dump_onnx_models:
+        if old_value is None:
+            del os.environ["ONNXRT_DUMP_PATH"]
+        else:
+            os.environ["ONNXRT_DUMP_PATH"] = old_value
+        new_files = glob.glob(f"{dump_prefix}*.onnx")
+        added_files = set(new_files) - set(existing_files)
+        return pred, res, [f for f in new_files if f in added_files]
+
+    return pred, res
diff --git a/onnxscript/tools/transformers_models/llama_test.py b/onnxscript/tools/transformers_models/llama_test.py
index ea4844476..5e697d20d 100644
--- a/onnxscript/tools/transformers_models/llama_test.py
+++ b/onnxscript/tools/transformers_models/llama_test.py
@@ -120,8 +120,6 @@ def test_llama_dort_static(self):
             onnxscript.tools.transformers_models.llama.get_llama_model()
         )
         input_tensors = input_tensors_many[0]
-        expected = model(*input_tensors)
-
         local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)
 
         compiled_model = torch.compile(
@@ -131,8 +129,17 @@ def test_llama_dort_static(self):
             fullgraph=True,
         )
 
-        results = compiled_model(*input_tensors)
-        torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)
+        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(  # pylint: disable=unbalanced-tuple-unpacking
+            model, *input_tensors
+        )
+        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
+            compiled_model,
+            *input_tensors,
+            dump_onnx_models=True,
+            dump_prefix="_dump_dort_llama",
+            dump_clean_first=True,
+        )
+        torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)
 
         expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
         gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
diff --git a/onnxscript/tools/transformers_models/mistral_test.py b/onnxscript/tools/transformers_models/mistral_test.py
index 7498b9a15..3f2ac020e 100644
--- a/onnxscript/tools/transformers_models/mistral_test.py
+++ b/onnxscript/tools/transformers_models/mistral_test.py
@@ -1,6 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-# pylint: disable=not-callable
+# pylint: disable=not-callable, unbalanced-tuple-unpacking
 
 import copy
 import sys
@@ -122,8 +122,6 @@ def test_mistral_dort_static(self):
             onnxscript.tools.transformers_models.mistral.get_mistral_model()
         )
         input_tensors = input_tensors_many[0]
-        expected = model(*input_tensors)
-
         local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)
 
         compiled_model = torch.compile(
@@ -133,12 +131,23 @@ def test_mistral_dort_static(self):
             fullgraph=True,
         )
 
-        results = compiled_model(*input_tensors)
-        torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)
+        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(  # pylint: disable=unbalanced-tuple-unpacking
+            model, *input_tensors
+        )
+        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
+            compiled_model,
+            *input_tensors,
+            dump_onnx_models=True,
+            dump_prefix="_dump_dort_mistral",
+            dump_clean_first=True,
+        )
+        torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)
 
-        expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
-        gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
-        torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)
+        # Check that only two graphs were generated; otherwise, there are graph breaks.
+        self.assertEqual(len(onnx_models), 2)
+        torch.testing.assert_allclose(
+            expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
+        )
 
 
 if __name__ == "__main__":
diff --git a/onnxscript/tools/transformers_models/phi3_test.py b/onnxscript/tools/transformers_models/phi3_test.py
index d9adcfd86..3b7081a6a 100644
--- a/onnxscript/tools/transformers_models/phi3_test.py
+++ b/onnxscript/tools/transformers_models/phi3_test.py
@@ -1,6 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-# pylint: disable=not-callable
+# pylint: disable=not-callable, unbalanced-tuple-unpacking
 
 import copy
 import sys
@@ -123,8 +123,6 @@ def test_phi3_dort_static(self):
             onnxscript.tools.transformers_models.phi3.get_phi3_model()
         )
         input_tensors = input_tensors_many[0]
-        expected = model(*input_tensors)
-
         local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)
 
         compiled_model = torch.compile(
@@ -134,12 +132,23 @@ def test_phi3_dort_static(self):
             fullgraph=True,
         )
 
-        results = compiled_model(*input_tensors)
-        torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)
+        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(  # pylint: disable=unbalanced-tuple-unpacking
+            model, *input_tensors
+        )
+        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
+            compiled_model,
+            *input_tensors,
+            dump_onnx_models=True,
+            dump_prefix="_dump_dort_phi3",
+            dump_clean_first=True,
+        )
+        torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)
 
-        expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
-        gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
-        torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)
+        # Check that only two graphs were generated; otherwise, there are graph breaks.
+        self.assertEqual(len(onnx_models), 2)
+        torch.testing.assert_allclose(
+            expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
+        )
 
 
 if __name__ == "__main__":
diff --git a/onnxscript/tools/transformers_models/phi_test.py b/onnxscript/tools/transformers_models/phi_test.py
index e835d8b1d..8897f5212 100644
--- a/onnxscript/tools/transformers_models/phi_test.py
+++ b/onnxscript/tools/transformers_models/phi_test.py
@@ -1,6 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-# pylint: disable=not-callable
+# pylint: disable=not-callable, unbalanced-tuple-unpacking
 
 import copy
 import sys
@@ -87,7 +87,6 @@ def test_phi_dort_static(self):
         model, input_tensors_many, _ = onnxscript.tools.transformers_models.phi.get_phi_model()
         input_tensors = input_tensors_many[0]
-        expected = model(*input_tensors)
 
         local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)
 
@@ -98,12 +97,23 @@ def test_phi_dort_static(self):
             fullgraph=True,
         )
 
-        results = compiled_model(*input_tensors)
-        torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5)
+        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(  # pylint: disable=unbalanced-tuple-unpacking
+            model, *input_tensors
+        )
+        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
+            compiled_model,
+            *input_tensors,
+            dump_onnx_models=True,
+            dump_prefix="_dump_dort_phi",
+            dump_clean_first=True,
+        )
+        torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)
 
-        expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
-        gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
-        torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)
+        # Check that only two graphs were generated; otherwise, there are graph breaks.
+        self.assertEqual(len(onnx_models), 2)
+        torch.testing.assert_allclose(
+            expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
+        )
 
 
 if __name__ == "__main__":
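
For reference, here is a minimal usage sketch (not part of the patch) of the new dump_onnx_models option added to onnxscript.tools.training_helper.train_loop. It mirrors what the tests in this change do; the TinyMLP model, the input shape, and the "_dump_example" prefix are illustrative assumptions only.

import torch

import onnxscript.tools.training_helper


class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(14, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.fc1(x))


model = TinyMLP()
# Compile the model with the onnxrt-based autograd backend.
local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)
compiled_model = torch.compile(model, backend=local_aot_ort, fullgraph=True)

# With dump_onnx_models=True, train_loop returns a third element:
# the list of ONNX files written by the onnxrt backend for this run.
pred, grads, onnx_files = onnxscript.tools.training_helper.train_loop(
    compiled_model,
    torch.randn(1, 1, 14, 14),
    dump_onnx_models=True,
    dump_prefix="_dump_example",
    dump_clean_first=True,
)
# Without graph breaks there should be exactly two files:
# one forward graph and one backward graph.
assert len(onnx_files) == 2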