Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VSTHRD114: Do not return null from non-async Task method #596

Merged
merged 14 commits into from
Apr 19, 2020
Merged
31 changes: 31 additions & 0 deletions doc/analyzers/VSTHRD112.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# VSTHRD112 Avoid returning a null Task

Returning `null` from a non-async `Task`/`Task<T>` method will cause a `NullReferenceException` at runtime. This problem can be avoided by returning `Task.CompletedTask`, `Task.FromResult<T>(null)` or `Task.FromResult(default(T))` instead.

## Examples of patterns that are flagged by this analyzer

Any non-async `Task` returning method with an explicit `return null;` will be flagged.

```csharp
Task DoAsync() {
return null;
}

Task<object> GetSomethingAsync() {
return null;
}
```

## Solution

Return a task like `Task.CompletedTask` or `Task.FromResult`.

```csharp
Task DoAsync() {
return Task.CompletedTask;
}

Task<object> GetSomethingAsync() {
return Task.FromResult<object>(null);
}
```
1 change: 1 addition & 0 deletions doc/analyzers/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ ID | Title | Severity | Supports | Default diagnostic severity
[VSTHRD109](VSTHRD109.md) | Switch instead of assert in async methods | Advisory | [1st rule](../threading_rules.md#Rule1) | Error
[VSTHRD110](VSTHRD110.md) | Observe result of async calls | Advisory | | Warning
[VSTHRD111](VSTHRD111.md) | Use `.ConfigureAwait(bool)` | Advisory | | Hidden
[VSTHRD112](VSTHRD112.md) | Avoid returning null from a `Task`-returning method. | Advisory | | Warning
[VSTHRD200](VSTHRD200.md) | Use `Async` naming convention | Guideline | [VSTHRD103](VSTHRD103.md) | Warning

## Severity descriptions
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
namespace Microsoft.VisualStudio.Threading.Analyzers
{
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Simplification;

[ExportCodeFixProvider(LanguageNames.CSharp)]
public class VSTHRD112AvoidReturningNullTaskCodeFix : CodeFixProvider
{
private static readonly ImmutableArray<string> ReusableFixableDiagnosticIds = ImmutableArray.Create(
VSTHRD112AvoidReturningNullTaskAnalyzer.Id);

/// <inheritdoc />
public override ImmutableArray<string> FixableDiagnosticIds => ReusableFixableDiagnosticIds;

/// <inheritdoc />
public override FixAllProvider GetFixAllProvider() => WellKnownFixAllProviders.BatchFixer;

public override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
foreach (var diagnostic in context.Diagnostics)
{
var semanticModel = await context.Document.GetSemanticModelAsync(context.CancellationToken).ConfigureAwait(false);
var syntaxRoot = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);

var nullLiteral = syntaxRoot.FindNode(diagnostic.Location.SourceSpan) as LiteralExpressionSyntax;
if (nullLiteral == null)
{
continue;
}

var methodDeclaration = nullLiteral.FirstAncestorOrSelf<MethodDeclarationSyntax>();
if (methodDeclaration == null)
{
continue;
}

if (!(methodDeclaration.ReturnType is GenericNameSyntax genericReturnType))
{
context.RegisterCodeFix(CodeAction.Create(Strings.VSTHRD112_CodeFix_CompletedTask, ct => ApplyTaskCompletedTaskFix(ct), "CompletedTask"), diagnostic);
}
else
{
if (genericReturnType.TypeArgumentList.Arguments.Count != 1)
{
continue;
}

context.RegisterCodeFix(CodeAction.Create(Strings.VSTHRD112_CodeFix_FromResult, ct => ApplyTaskFromResultFix(genericReturnType.TypeArgumentList.Arguments[0], ct), "FromResult"), diagnostic);
}

Task<Document> ApplyTaskCompletedTaskFix(CancellationToken cancellationToken)
{
ExpressionSyntax completedTaskExpression = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName("Task"),
SyntaxFactory.IdentifierName("CompletedTask"))
.WithAdditionalAnnotations(Simplifier.Annotation);

return Task.FromResult(context.Document.WithSyntaxRoot(syntaxRoot.ReplaceNode(nullLiteral, completedTaskExpression)));
}

Task<Document> ApplyTaskFromResultFix(TypeSyntax returnTypeArgument, CancellationToken cancellationToken)
{
ExpressionSyntax completedTaskExpression = SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName("Task"),
SyntaxFactory.GenericName("FromResult").AddTypeArgumentListArguments(returnTypeArgument)))
.AddArgumentListArguments(SyntaxFactory.Argument(nullLiteral))
.WithAdditionalAnnotations(Simplifier.Annotation);

return Task.FromResult(context.Document.WithSyntaxRoot(syntaxRoot.ReplaceNode(nullLiteral, completedTaskExpression)));
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,37 +35,38 @@ public Test()

this.SolutionTransforms.Add((solution, projectId) =>
{
var parseOptions = (CSharpParseOptions)solution.GetProject(projectId).ParseOptions;
solution = solution.WithProjectParseOptions(projectId, parseOptions.WithLanguageVersion(LanguageVersion.CSharp7_1));
Project? project = solution.GetProject(projectId);

var parseOptions = (CSharpParseOptions)project!.ParseOptions;
Evangelink marked this conversation as resolved.
Show resolved Hide resolved
project = project.WithParseOptions(parseOptions.WithLanguageVersion(LanguageVersion.CSharp7_1));

if (this.HasEntryPoint)
{
var compilationOptions = solution.GetProject(projectId).CompilationOptions;
solution = solution.WithProjectCompilationOptions(projectId, compilationOptions.WithOutputKind(OutputKind.ConsoleApplication));
project = project.WithCompilationOptions(project.CompilationOptions.WithOutputKind(OutputKind.ConsoleApplication));
}

if (this.IncludeMicrosoftVisualStudioThreading)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(JoinableTaskFactory).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(JoinableTaskFactory).Assembly.Location));
}

if (this.IncludeWindowsBase)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(Dispatcher).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(Dispatcher).Assembly.Location));
}

if (this.IncludeVisualStudioSdk)
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(typeof(IOleServiceProvider).Assembly.Location));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(typeof(IOleServiceProvider).Assembly.Location));

var nugetPackagesFolder = Environment.CurrentDirectory;
foreach (var reference in VSSDKPackageReferences)
foreach (var reference in CSharpCodeFixVerifier<TAnalyzer, TCodeFix>.Test.VSSDKPackageReferences)
Evangelink marked this conversation as resolved.
Show resolved Hide resolved
{
solution = solution.AddMetadataReference(projectId, MetadataReference.CreateFromFile(Path.Combine(nugetPackagesFolder, reference)));
project = project.AddMetadataReference(MetadataReference.CreateFromFile(Path.Combine(nugetPackagesFolder, reference)));
}
}

return solution;
return project.Solution;
});

this.TestState.AdditionalFilesFactories.Add(() =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#nullable enable

namespace Microsoft.VisualStudio.Threading.Analyzers.Tests
{
using System;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.VisualStudio.Threading.Analyzers.Lightup;
using Xunit;

public class LightupHelpersTests
{
[Theory]
[InlineData(null)]
[InlineData(typeof(SyntaxNode))]
public void TestCanAccessNonExistentSyntaxProperty(Type type)
{
var fallbackResult = new object();

var propertyAccessor = LightupHelpers.CreateSyntaxPropertyAccessor<SyntaxNode, object>(type, "NonExistentProperty", fallbackResult);
Assert.NotNull(propertyAccessor);
Assert.Same(fallbackResult, propertyAccessor(SyntaxFactory.AccessorList()));
Assert.Throws<NullReferenceException>(() => propertyAccessor(null!));

var withPropertyAccessor = LightupHelpers.CreateSyntaxWithPropertyAccessor<SyntaxNode, object>(type, "NonExistentProperty", fallbackResult);
Assert.NotNull(withPropertyAccessor);
Assert.NotNull(withPropertyAccessor(SyntaxFactory.AccessorList(), fallbackResult));
Assert.ThrowsAny<NotSupportedException>(() => withPropertyAccessor(SyntaxFactory.AccessorList(), new object()));
Assert.Throws<NullReferenceException>(() => withPropertyAccessor(null!, new object()));
}

[Theory]
[InlineData(null)]
[InlineData(typeof(EmptySymbol))]
public void TestCanAccessNonExistentSymbolProperty(Type type)
{
var fallbackResult = new object();

var propertyAccessor = LightupHelpers.CreateSymbolPropertyAccessor<ISymbol, object>(type, "NonExistentProperty", fallbackResult);
Assert.NotNull(propertyAccessor);
Assert.Same(fallbackResult, propertyAccessor(new EmptySymbol()));
Assert.Throws<NullReferenceException>(() => propertyAccessor(null!));

var withPropertyAccessor = LightupHelpers.CreateSymbolWithPropertyAccessor<ISymbol, object>(type, "NonExistentProperty", fallbackResult);
Assert.NotNull(withPropertyAccessor);
Assert.NotNull(withPropertyAccessor(new EmptySymbol(), fallbackResult));
Assert.ThrowsAny<NotSupportedException>(() => withPropertyAccessor(new EmptySymbol(), new object()));
Assert.Throws<NullReferenceException>(() => withPropertyAccessor(null!, new object()));
}

[Theory]
[InlineData(null)]
[InlineData(typeof(SyntaxNode))]
public void TestCanAccessNonExistentMethodWithArgument(Type type)
{
var fallbackResult = new object();

var accessor = LightupHelpers.CreateAccessorWithArgument<SyntaxNode, int, object?>(type, "parameterName", typeof(int), "argumentName", "NonExistentMethod", fallbackResult);
Assert.NotNull(accessor);
Assert.Same(fallbackResult, accessor(SyntaxFactory.AccessorList(), 0));
Assert.Throws<NullReferenceException>(() => accessor(null!, 0));
}

[Fact]
public void TestCreateSyntaxPropertyAccessor()
{
// The call *should* have been made with the first generic argument set to `BaseMethodDeclarationSyntax`
// instead of `MethodDeclarationSyntax`.
Assert.ThrowsAny<InvalidOperationException>(() => LightupHelpers.CreateSyntaxPropertyAccessor<MethodDeclarationSyntax, BlockSyntax?>(typeof(BaseMethodDeclarationSyntax), nameof(BaseMethodDeclarationSyntax.Body), fallbackResult: null));

// The call *should* have been made with the second generic argument set to `ArrowExpressionClauseSyntax`
// instead of `BlockSyntax`.
Assert.ThrowsAny<InvalidOperationException>(() => LightupHelpers.CreateSyntaxPropertyAccessor<MethodDeclarationSyntax, BlockSyntax?>(typeof(MethodDeclarationSyntax), nameof(MethodDeclarationSyntax.ExpressionBody), fallbackResult: null));
}

[Fact]
public void TestCreateSyntaxWithPropertyAccessor()
{
// The call *should* have been made with the first generic argument set to `BaseMethodDeclarationSyntax`
// instead of `MethodDeclarationSyntax`.
Assert.ThrowsAny<InvalidOperationException>(() => LightupHelpers.CreateSyntaxWithPropertyAccessor<MethodDeclarationSyntax, BlockSyntax?>(typeof(BaseMethodDeclarationSyntax), nameof(BaseMethodDeclarationSyntax.Body), fallbackResult: null));

// The call *should* have been made with the second generic argument set to `ArrowExpressionClauseSyntax`
// instead of `BlockSyntax`.
Assert.ThrowsAny<InvalidOperationException>(() => LightupHelpers.CreateSyntaxWithPropertyAccessor<MethodDeclarationSyntax, BlockSyntax?>(typeof(MethodDeclarationSyntax), nameof(MethodDeclarationSyntax.ExpressionBody), fallbackResult: null));
}

[SuppressMessage("MicrosoftCodeAnalysisCompatibility", "RS1009:Only internal implementations of this interface are allowed.", Justification = "Stub for testing.")]
private class EmptySymbol : ISymbol
{
SymbolKind ISymbol.Kind => throw new NotImplementedException();

string ISymbol.Language => throw new NotImplementedException();

string ISymbol.Name => throw new NotImplementedException();

string ISymbol.MetadataName => throw new NotImplementedException();

ISymbol ISymbol.ContainingSymbol => throw new NotImplementedException();

IAssemblySymbol ISymbol.ContainingAssembly => throw new NotImplementedException();

IModuleSymbol ISymbol.ContainingModule => throw new NotImplementedException();

INamedTypeSymbol ISymbol.ContainingType => throw new NotImplementedException();

INamespaceSymbol ISymbol.ContainingNamespace => throw new NotImplementedException();

bool ISymbol.IsDefinition => throw new NotImplementedException();

bool ISymbol.IsStatic => throw new NotImplementedException();

bool ISymbol.IsVirtual => throw new NotImplementedException();

bool ISymbol.IsOverride => throw new NotImplementedException();

bool ISymbol.IsAbstract => throw new NotImplementedException();

bool ISymbol.IsSealed => throw new NotImplementedException();

bool ISymbol.IsExtern => throw new NotImplementedException();

bool ISymbol.IsImplicitlyDeclared => throw new NotImplementedException();

bool ISymbol.CanBeReferencedByName => throw new NotImplementedException();

ImmutableArray<Location> ISymbol.Locations => throw new NotImplementedException();

ImmutableArray<SyntaxReference> ISymbol.DeclaringSyntaxReferences => throw new NotImplementedException();

Accessibility ISymbol.DeclaredAccessibility => throw new NotImplementedException();

ISymbol ISymbol.OriginalDefinition => throw new NotImplementedException();

bool ISymbol.HasUnsupportedMetadata => throw new NotImplementedException();

void ISymbol.Accept(SymbolVisitor visitor)
=> throw new NotImplementedException();

TResult ISymbol.Accept<TResult>(SymbolVisitor<TResult> visitor)
=> throw new NotImplementedException();

public bool Equals(ISymbol other)
=> throw new NotImplementedException();

ImmutableArray<AttributeData> ISymbol.GetAttributes()
=> throw new NotImplementedException();

string ISymbol.GetDocumentationCommentId()
=> throw new NotImplementedException();

string ISymbol.GetDocumentationCommentXml(CultureInfo preferredCulture, bool expandIncludes, CancellationToken cancellationToken)
=> throw new NotImplementedException();

ImmutableArray<SymbolDisplayPart> ISymbol.ToDisplayParts(SymbolDisplayFormat format)
=> throw new NotImplementedException();

string ISymbol.ToDisplayString(SymbolDisplayFormat format)
=> throw new NotImplementedException();

ImmutableArray<SymbolDisplayPart> ISymbol.ToMinimalDisplayParts(SemanticModel semanticModel, int position, SymbolDisplayFormat format)
=> throw new NotImplementedException();

string ISymbol.ToMinimalDisplayString(SemanticModel semanticModel, int position, SymbolDisplayFormat format)
=> throw new NotImplementedException();

public bool Equals(ISymbol other, SymbolEqualityComparer equalityComparer) => throw new NotImplementedException();
Evangelink marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
<PackageReference Include="MicroBuild.VisualStudio" Version="$(MicroBuildVersion)" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.CodeFix.Testing.XUnit" Version="1.0.1-beta1.20059.2" />
<PackageReference Include="Microsoft.CodeAnalysis.VisualBasic.CodeFix.Testing.XUnit" Version="1.0.1-beta1.20059.2" />
<PackageReference Include="Microsoft.CodeAnalysis" Version="2.8.2" />
<PackageReference Include="Microsoft.CodeAnalysis" Version="3.3.0" />
<PackageReference Include="Microsoft.VisualStudio.OLE.Interop" Version="7.10.6070" />
<PackageReference Include="Microsoft.VisualStudio.Shell.14.0" Version="14.3.25407" IncludeAssets="runtime" />
<PackageReference Include="Microsoft.VisualStudio.Shell.Interop.11.0" Version="11.0.61030" IncludeAssets="runtime" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
using System.Linq;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Testing;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Testing;
using Microsoft.CodeAnalysis.Testing.Verifiers;
using Xunit;
using Verify = MultiAnalyzerTests.Verifier;

Expand All @@ -32,7 +30,7 @@ Task<int> FooAsync() {
return Task.FromResult(1);
}
Task BarAsync() => null;
Task BarAsync() => Task.CompletedTask;
static void SetTaskSourceIfCompleted<T>(Task<T> task, TaskCompletionSource<T> tcs) {
if (task.IsCompleted) {
Expand Down Expand Up @@ -171,7 +169,7 @@ public Task BAsync() {
E().ToString();
E()();
string v = nameof(E);
return null;
return Task.CompletedTask;
}
internal Task CAsync() {
Expand All @@ -181,7 +179,7 @@ internal Task CAsync() {
E().ToString();
E()();
string v = nameof(E);
return null;
return Task.CompletedTask;
}
private void D<T>() { }
Expand Down
Loading