Recently, a friend of mine asked me whether there’s a better way to run integration tests than using an in-memory database, because an in-memory database does not behave exactly like the real thing. In my opinion, the repository should be simple and not contain any business logic. Hence, mocking the data in the in-memory database for the integration tests should be sufficient most of the time.
However, we also want integration tests that use a real database to ensure everything works together. One approach we use in our project at work is to run the integration tests against a dockerized SQL Server. In this post, I’m going to show how we set this up.
Sample WebAPI
We have a sample Web API that uses Entity Framework Core to return blogs, like this:
/// <summary>
/// EF Core context exposing the <see cref="Blog"/> and <see cref="Post"/> sets used by the sample API.
/// </summary>
public class BloggingContext : DbContext
{
    /// <summary>
    /// Creates the context with externally supplied options (provider, connection string, ...).
    /// </summary>
    /// <param name="options">Options registered for this context type.</param>
    public BloggingContext(DbContextOptions<BloggingContext> options)
        : base(options)
    {
    }

    public DbSet<Blog> Blogs { get; set; }

    public DbSet<Post> Posts { get; set; }

    /// <summary>
    /// Marks both primary keys as never generated by the database,
    /// so the application must assign the Guid keys itself before saving.
    /// </summary>
    protected override void OnModelCreating(ModelBuilder modelBuilder)
    {
        modelBuilder.Entity<Blog>(blog =>
            blog.Property(b => b.BlogId).ValueGeneratedNever());

        modelBuilder.Entity<Post>(post =>
            post.Property(p => p.PostId).ValueGeneratedNever());
    }
}
/// <summary>
/// A blog entity together with the posts that belong to it.
/// </summary>
public class Blog
{
    /// <summary>Primary key. The database does not generate it, so it must be set before saving.</summary>
    [Key] public Guid BlogId { get; set; }

    public string Name { get; set; }

    public DateTime PublishedDate { get; set; }

    /// <summary>Posts of this blog; initialized to an empty list so it is never null on a new instance.</summary>
    public ICollection<Post> Posts { get; set; } = new List<Post>();
}
/// <summary>
/// A single post, linked to its owning <see cref="Blog"/> by <see cref="BlogId"/>.
/// </summary>
public class Post
{
    /// <summary>Primary key. The database does not generate it, so it must be set before saving.</summary>
    [Key] public Guid PostId { get; set; }

    public string Title { get; set; }

    public DateTime PublishedDate { get; set; }

    /// <summary>Foreign key of the owning blog.</summary>
    public Guid BlogId { get; set; }

    /// <summary>Navigation back to the owning blog.</summary>
    public Blog Blog { get; set; }
}
/// <summary>
/// Read-only endpoint over the blogs stored in <see cref="BloggingContext"/>.
/// </summary>
[ApiController]
[Route("[controller]")]
public class SampleController : ControllerBase
{
    private readonly BloggingContext _bloggingContext;

    public SampleController(BloggingContext bloggingContext)
    {
        _bloggingContext = bloggingContext;
    }

    /// <summary>
    /// Returns every blog with its id, name and the titles of its posts.
    /// </summary>
    /// <param name="cancellationToken">Cancels the database query if the request is aborted.</param>
    /// <returns>200 OK with the list of <see cref="BlogDto"/>.</returns>
    [HttpGet]
    public async Task<IActionResult> Get(System.Threading.CancellationToken cancellationToken = default)
    {
        // The projection already tells EF which columns and related rows to load, so
        // Include() is unnecessary. AsNoTracking skips change tracking for this read-only query.
        var blogs = await _bloggingContext.Blogs
            .AsNoTracking()
            .Select(x => new BlogDto
            {
                Id = x.BlogId,
                Name = x.Name,
                Posts = x.Posts.Select(p => new PostDto
                {
                    Title = p.Title
                }).ToList()
            })
            // Materialize asynchronously here rather than handing a live IQueryable to the serializer.
            .ToListAsync(cancellationToken);

        return Ok(blogs);
    }

    /// <summary>Blog shape returned by the API.</summary>
    public class BlogDto
    {
        public Guid Id { get; set; }
        public string Name { get; set; }
        public IEnumerable<PostDto> Posts { get; set; } = new List<PostDto>();
    }

    /// <summary>Post shape returned by the API (title only).</summary>
    public class PostDto
    {
        public string Title { get; set; }
    }
}
Setting up Docker
- Install Docker Engine if you don’t already have it. You can install it using Docker Desktop.
- Create a new `docker-compose.yml` file:
version: "3.9"
services:
  db:
    image: "mcr.microsoft.com/mssql/server:2022-latest"
    container_name: sql-server-2022
    environment:
      # MSSQL_SA_PASSWORD replaces the deprecated SA_PASSWORD variable.
      # The password must satisfy SQL Server's complexity policy and must match
      # the password in the integration-test connection string.
      MSSQL_SA_PASSWORD: "P@ssw0rd"
      ACCEPT_EULA: "Y"
      MSSQL_PID: "Express"
    ports:
      # Expose the default SQL Server port on the host so tests can use localhost,1433.
      - "1433:1433"
- Start the container. Open a terminal, navigate to the folder that contains the `docker-compose.yml` file, and run `docker compose up`. The first run may take a while because Docker has to download the image. You should see something like this once it’s ready:
Once it’s running, open SQL Server Management Studio (SSMS) and check that you can connect to the SQL Server in the Docker container (server `localhost,1433`, user `sa`) with the password from the YAML file above.
Setting up Integration Tests
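Add an `appsettings.IntegrationTest.json` file for the tests. The custom factory below loads it from the folder of the Web API project (the project that contains `Startup`), so place it there: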
{
  "ConnectionStrings": {
    "DefaultConnection": "Server=localhost,1433;Initial Catalog=BlogsDb;User=sa;Password=P@ssw0rd;TrustServerCertificate=True;"
  }
}
Add a static class, `Seeder.cs`, that uses AutoFixture to seed some dummy data for the tests:
/// <summary>
/// Test-data helper that fills an empty database with AutoFixture-generated blogs.
/// </summary>
public static class Seeder
{
    /// <summary>
    /// Inserts 100 generated <see cref="Blog"/> entities and saves them, but only when the
    /// Blogs table is empty. If any blog already exists, nothing is added, so repeated runs
    /// against the same database don't duplicate data.
    /// </summary>
    /// <param name="dbContext">Context to seed; changes are saved immediately.</param>
    public static void Seed(this BloggingContext dbContext)
    {
        if (dbContext.Blogs.Any())
        {
            return;
        }

        var fixture = new Fixture();

        // Blog.Posts -> Post.Blog forms a cycle. Replace the behavior that throws on
        // recursion with one that omits the recursive member instead.
        var throwingBehaviors = fixture.Behaviors
            .OfType<ThrowingRecursionBehavior>()
            .ToList();
        foreach (var behavior in throwingBehaviors)
        {
            fixture.Behaviors.Remove(behavior);
        }
        fixture.Behaviors.Add(new OmitOnRecursionBehavior(1));

        var generatedBlogs = fixture.CreateMany<Blog>(100);
        dbContext.AddRange(generatedBlogs);
        dbContext.SaveChanges();
    }
}
Add a custom web application factory that points the app at the dockerized database, then creates and seeds it:
/// <summary>
/// Test host that swaps the app's <see cref="BloggingContext"/> registration for one pointing at the
/// connection string in <c>appsettings.IntegrationTest.json</c>, then creates and seeds that database.
/// </summary>
/// <typeparam name="TStartup">Startup class of the application under test.</typeparam>
public class CustomWebApplicationFactory<TStartup>
    : WebApplicationFactory<TStartup> where TStartup : class
{
    /// <summary>
    /// Locates the folder of the project whose assembly contains <typeparamref name="TStartup"/>.
    /// It walks up from <see cref="AppContext.BaseDirectory"/> looking for
    /// <c>&lt;ProjectName&gt;/&lt;ProjectName&gt;.csproj</c>.
    /// </summary>
    /// <returns>Full path of the project folder.</returns>
    /// <exception cref="InvalidOperationException">No matching project file was found up to the file-system root.</exception>
    public static string GetProjectPath()
    {
        var projectName = typeof(TStartup).Assembly.GetName().Name;
        var applicationBasePath = AppContext.BaseDirectory;

        // Start from the parent of the output folder and stop after checking the root.
        // The null check also covers an output folder that is itself a root.
        for (var directoryInfo = new DirectoryInfo(applicationBasePath).Parent;
             directoryInfo != null;
             directoryInfo = directoryInfo.Parent)
        {
            var projectDirectory = Path.Combine(directoryInfo.FullName, projectName);
            if (File.Exists(Path.Combine(projectDirectory, $"{projectName}.csproj")))
            {
                return projectDirectory;
            }
        }

        throw new InvalidOperationException($"Project root could not be located using the application root {applicationBasePath}.");
    }

    /// <summary>
    /// Loads the integration-test configuration, then re-registers the DbContext and initializes the database.
    /// </summary>
    protected override void ConfigureWebHost(IWebHostBuilder builder)
    {
        var configuration = new ConfigurationBuilder()
            .SetBasePath(GetProjectPath())
            .AddJsonFile("appsettings.IntegrationTest.json")
            .AddEnvironmentVariables()
            .Build();

        builder.ConfigureServices(services =>
        {
            ReplaceBloggingContext(services, configuration.GetConnectionString("DefaultConnection"));
            InitializeDatabase(services);
        });
    }

    // Removes the app's DbContextOptions<BloggingContext> and registers SQL Server options for the test database.
    private static void ReplaceBloggingContext(IServiceCollection services, string connectionString)
    {
        var descriptor = services.SingleOrDefault(
            d => d.ServiceType == typeof(DbContextOptions<BloggingContext>));
        if (descriptor != null)
        {
            services.Remove(descriptor);
        }

        services.AddDbContext<BloggingContext>(options => options.UseSqlServer(connectionString));
    }

    // Creates the schema if missing and seeds it, using a short-lived provider that is disposed afterwards.
    private static void InitializeDatabase(IServiceCollection services)
    {
        // This provider is separate from the one the test host builds. Dispose it so the
        // singletons it resolves (e.g. the logger factory) are not leaked.
        using var serviceProvider = services.BuildServiceProvider();
        using var scope = serviceProvider.CreateScope();
        var scopedServices = scope.ServiceProvider;
        var db = scopedServices.GetRequiredService<BloggingContext>();
        var logger = scopedServices
            .GetRequiredService<ILogger<CustomWebApplicationFactory<TStartup>>>();

        db.Database.EnsureCreated();

        try
        {
            Seeder.Seed(db);
        }
        catch (Exception ex)
        {
            // Best-effort: a seeding failure is logged rather than aborting host startup.
            logger.LogError(ex, "An error occurred seeding the " +
                "database with test messages. Error: {Message}", ex.Message);
        }
    }
}
Adding a test
The following is a sample test that performs a GET request to get all the blogs.
/// <summary>
/// Integration tests that call the API hosted by <see cref="CustomWebApplicationFactory{TStartup}"/>
/// against the dockerized SQL Server.
/// </summary>
[TestFixture]
public class SampleTests
{
    private readonly HttpClient _client;
    private readonly CustomWebApplicationFactory<Startup> _factory;

    public SampleTests()
    {
        _factory = new CustomWebApplicationFactory<Startup>();
        // Don't follow redirects so any redirect response is visible to the assertions as-is.
        _client = _factory.CreateClient(new WebApplicationFactoryClientOptions
        {
            AllowAutoRedirect = false
        });
    }

    /// <summary>
    /// Releases the client and shuts down the test host once every test in the fixture has run.
    /// </summary>
    [OneTimeTearDown]
    public void OneTimeTearDown()
    {
        _client.Dispose();
        _factory.Dispose();
    }

    /// <summary>
    /// GET /sample returns all blogs. This assumes the database holds only the 100 blogs from the seeder.
    /// </summary>
    [Test]
    public async Task Test_Get_Blogs()
    {
        // Arrange
        var request = "/sample";

        // Act
        var response = await _client.GetAsync(request);

        // Assert
        response.EnsureSuccessStatusCode();
        var jsonResult = await response.Content.ReadAsStringAsync();
        var blogs = JsonConvert.DeserializeObject<IEnumerable<BlogDto>>(jsonResult);
        blogs.Should().HaveCount(100);
    }
}