Skip to content

Commit

Permalink
Enable nullable everywhere (#954)
Browse files Browse the repository at this point in the history
  • Loading branch information
j3parker authored Aug 7, 2024
1 parent 907b8cc commit 63fed57
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 30 deletions.
1 change: 1 addition & 0 deletions Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
<LangVersion>10.0</LangVersion>
<ImplicitUsings>true</ImplicitUsings>
<InvariantGlobalization>true</InvariantGlobalization>
<Nullable>enable</Nullable>
</PropertyGroup>
</Project>
1 change: 0 additions & 1 deletion src/D2L.CodeStyle.Analyzers/D2L.CodeStyle.Analyzers.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
<GeneratePackageOnBuild>True</GeneratePackageOnBuild>
<IncludeBuildOutput>False</IncludeBuildOutput>
<IncludeSymbols>True</IncludeSymbols>
<Nullable>enable</Nullable>
<WarningsAsErrors>CA2016,Nullable</WarningsAsErrors>
<TreatWarningsAsErrors>false</TreatWarningsAsErrors>
<EnforceExtendedAnalyzerRules>true</EnforceExtendedAnalyzerRules>
Expand Down
14 changes: 14 additions & 0 deletions src/D2L.CodeStyle.TestAnalyzers/Common/NotNullWhenAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#if !NETSTANDARD2_1

namespace System.Diagnostics.CodeAnalysis;

[AttributeUsage( AttributeTargets.Parameter )]
internal sealed class NotNullWhenAttribute : Attribute {
public bool ReturnValue { get; }

public NotNullWhenAttribute( bool returnValue ) {
ReturnValue = returnValue;
}
}

#endif
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Text;
using Microsoft.CodeAnalysis;

Expand All @@ -9,7 +10,9 @@ public static class RoslynExtensions {

// Copied from the non-test assembly because we do not reference it.

public static bool IsNullOrErrorType( this ITypeSymbol symbol ) {
public static bool IsNullOrErrorType(
[NotNullWhen( false )] this ITypeSymbol? symbol
) {

if( symbol == null ) {
return true;
Expand All @@ -26,7 +29,9 @@ public static bool IsNullOrErrorType( this ITypeSymbol symbol ) {
return false;
}

public static bool IsNullOrErrorType( this ISymbol symbol ) {
public static bool IsNullOrErrorType(
[NotNullWhen( false )] this ISymbol? symbol
) {

if( symbol == null ) {
return true;
Expand Down
67 changes: 45 additions & 22 deletions src/D2L.CodeStyle.TestAnalyzers/NUnit/CategoryAnalyzer.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using D2L.CodeStyle.TestAnalyzers.Common;
Expand Down Expand Up @@ -42,7 +43,7 @@ public override void Initialize( AnalysisContext context ) {
}

private static void OnCompilationStart( CompilationStartAnalysisContext context ) {
if( !TryLoadNUnitTypes( context.Compilation, out NUnitTypes types ) ) {
if( !TryLoadNUnitTypes( context.Compilation, out NUnitTypes? types ) ) {
return;
}

Expand All @@ -53,7 +54,7 @@ private static void OnCompilationStart( CompilationStartAnalysisContext context
context: ctx,
bannedCategories: bannedCategories,
types: types,
syntax: ctx.Node as MethodDeclarationSyntax
syntax: (MethodDeclarationSyntax)ctx.Node
),
SyntaxKind.MethodDeclaration
);
Expand All @@ -69,9 +70,12 @@ private static void OnCompilationEnd( CompilationAnalysisContext context, NUnitT
return;
}

// TODO: can this actually be null?
var syntaxReference = attribute.ApplicationSyntaxReference!;

context.ReportDiagnostic( Diagnostic.Create(
Diagnostics.NUnitCategory,
attribute.ApplicationSyntaxReference.GetSyntax( context.CancellationToken ).GetLocation(),
syntaxReference.GetSyntax( context.CancellationToken ).GetLocation(),
$"Assemblies cannot be categorized as any of [{string.Join( ", ", ProhibitedAssemblyCategories )}], but saw '{category}'."
) );
} );
Expand All @@ -85,7 +89,7 @@ MethodDeclarationSyntax syntax
) {
SemanticModel model = context.SemanticModel;

IMethodSymbol method = model.GetDeclaredSymbol( syntax, context.CancellationToken );
IMethodSymbol? method = model.GetDeclaredSymbol( syntax, context.CancellationToken );
if( method == null ) {
return;
}
Expand Down Expand Up @@ -117,8 +121,8 @@ private static bool IsTestMethod(
IMethodSymbol method
) {
foreach( AttributeData attribute in method.GetAttributes() ) {
INamedTypeSymbol attributeType = attribute.AttributeClass;
if( types.TestAttributes.Contains( attributeType ) ) {
INamedTypeSymbol? attributeType = attribute.AttributeClass;
if( attributeType != null && types.TestAttributes.Contains( attributeType ) ) {
return true;
}
}
Expand Down Expand Up @@ -159,7 +163,7 @@ private static void VisitCategories(
Action<string, AttributeData> visitor
) {
foreach( AttributeData attribute in symbol.GetAttributes() ) {
INamedTypeSymbol attributeType = attribute.AttributeClass;
INamedTypeSymbol? attributeType = attribute.AttributeClass;
if( types.CategoryAttribute.Equals( attributeType, SymbolEqualityComparer.Default ) ) {
VisitCategoryAttribute( attribute, visitor );
continue;
Expand Down Expand Up @@ -189,11 +193,16 @@ Action<string, AttributeData> visitor

TypedConstant arg = args[0];

if( arg.Type == null ) {
return;
}

if( arg.Type.SpecialType != SpecialType.System_String ) {
return;
}

string category = arg.Value as string;
string category = (string?)arg.Value!;

visitor( category, attribute );
}

Expand All @@ -208,11 +217,16 @@ Action<string, AttributeData> visitor

TypedConstant arg = namedArg.Value;

if( arg.Type == null ) {
continue;
}

if( arg.Type.SpecialType != SpecialType.System_String ) {
continue;
}

string categoryCsv = arg.Value as string;
string categoryCsv = (string?)arg.Value!;

foreach( string category in categoryCsv.Split( ',' ) ) {
visitor( category.Trim(), attribute );
}
Expand All @@ -221,24 +235,29 @@ Action<string, AttributeData> visitor

private static bool TryLoadNUnitTypes(
Compilation compilation,
out NUnitTypes types
[NotNullWhen( true )]
out NUnitTypes? types
) {
INamedTypeSymbol categoryAttribute = compilation.GetTypeByMetadataName( "NUnit.Framework.CategoryAttribute" );
INamedTypeSymbol? categoryAttribute = compilation.GetTypeByMetadataName( "NUnit.Framework.CategoryAttribute" );
if( categoryAttribute == null || categoryAttribute.TypeKind == TypeKind.Error ) {
types = null;
return false;
}

ImmutableHashSet<INamedTypeSymbol> testAttributes = ImmutableHashSet
.Create<INamedTypeSymbol>(
SymbolEqualityComparer.Default,
compilation.GetTypeByMetadataName( "NUnit.Framework.TestAttribute" ),
compilation.GetTypeByMetadataName( "NUnit.Framework.TestCaseAttribute" ),
compilation.GetTypeByMetadataName( "NUnit.Framework.TestCaseSourceAttribute" ),
compilation.GetTypeByMetadataName( "NUnit.Framework.TheoryAttribute" )
);
ImmutableHashSet<INamedTypeSymbol> testAttributes = new[] {
compilation.GetTypeByMetadataName( "NUnit.Framework.TestAttribute" ),
compilation.GetTypeByMetadataName( "NUnit.Framework.TestCaseAttribute" ),
compilation.GetTypeByMetadataName( "NUnit.Framework.TestCaseSourceAttribute" ),
compilation.GetTypeByMetadataName( "NUnit.Framework.TheoryAttribute" )
}.Where( x => x != null )!
.ToImmutableHashSet<INamedTypeSymbol>( SymbolEqualityComparer.Default );

INamedTypeSymbol testFixtureAttribute = compilation.GetTypeByMetadataName( "NUnit.Framework.TestFixtureAttribute" );
INamedTypeSymbol? testFixtureAttribute = compilation.GetTypeByMetadataName( "NUnit.Framework.TestFixtureAttribute" );

if( testFixtureAttribute == null ) {
types = null;
return false;
}

types = new NUnitTypes( categoryAttribute, testAttributes, testFixtureAttribute );
return true;
Expand Down Expand Up @@ -268,15 +287,19 @@ AnalyzerOptions options
StringComparer.Ordinal
);

AdditionalText bannedListFile = options.AdditionalFiles.FirstOrDefault(
AdditionalText? bannedListFile = options.AdditionalFiles.FirstOrDefault(
file => Path.GetFileName( file.Path ) == "BannedTestCategoriesList.txt"
);

if( bannedListFile == null ) {
return bannedList.ToImmutableHashSet();
}

SourceText allowedListText = bannedListFile.GetText();
SourceText? allowedListText = bannedListFile.GetText();

if( allowedListText == null ) {
throw new Exception( "Couldn't read config" );
}

foreach( TextLine line in allowedListText.Lines ) {
bannedList.Add( line.ToString().Trim() );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public override void Initialize( AnalysisContext context ) {

private static void Register( CompilationStartAnalysisContext context ) {

INamedTypeSymbol attributeType =
INamedTypeSymbol? attributeType =
context.Compilation.GetTypeByMetadataName( AttributeTypeName );

if( attributeType.IsNullOrErrorType() ) {
Expand Down Expand Up @@ -73,7 +73,7 @@ INamedTypeSymbol attributeType

foreach( AttributeSyntax attribute in attributes ) {

ISymbol symbol = context
ISymbol? symbol = context
.SemanticModel
.GetSymbolInfo( attribute, context.CancellationToken )
.Symbol;
Expand All @@ -88,9 +88,9 @@ INamedTypeSymbol attributeType
}

SeparatedSyntaxList<AttributeArgumentSyntax> arguments =
attribute.ArgumentList.Arguments;
attribute.ArgumentList?.Arguments ?? default;

if( arguments.Count != 1 ) {
if( arguments == default || arguments.Count != 1 ) {
continue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public override void Initialize( AnalysisContext context ) {
public static void RegisterServiceLocatorAnalyzer(
CompilationStartAnalysisContext context
) {
INamedTypeSymbol factoryType = context.Compilation
INamedTypeSymbol? factoryType = context.Compilation
.GetTypeByMetadataName( TestServiceLocatorFactoryType );

if( factoryType.IsNullOrErrorType() ) {
Expand Down

0 comments on commit 63fed57

Please sign in to comment.