For unit testing, I want to use the more common approach of creating a mock for my DbContext. To get this working, I created a FakeDbSet that implements DbSet<TEntity>, IQueryable<TEntity> and IAsyncEnumerable<TEntity>. The FakeDbSet holds its data internally to fake the Add, Find and Remove methods. Since I am using code first and .Include and .ThenInclude to create joins, I need to mock these methods. How can I do this?
Using the InMemory database is no longer an option for me. I ran into multiple issues because of parallel test execution, and I don't want to test Entity Framework Core in my tests; I want to test my repository class. For unit testing, mocking is a very common approach.
This is my implementation of the mock and an example of how I use it:
/// <summary>
/// In-memory fake of <see cref="DbSet{TEntity}"/> for unit tests. The list passed to the
/// constructor is the backing store: queries run against it, and Add/Remove modify it directly.
/// </summary>
/// <typeparam name="TEntity">The entity type held by the set.</typeparam>
/// <remarks>
/// NOTE(review): EF Core's <c>Include</c>/<c>ThenInclude</c> extension methods appear to return the
/// query unchanged when its provider is not EF Core's own query provider, so they should not need to
/// be mocked for this fake. Related entities must instead be assigned to the navigation properties
/// of the test data. Confirm this against the EF Core version in use.
/// </remarks>
class FakeDbSet<TEntity> : DbSet<TEntity>, IQueryable<TEntity>, IAsyncEnumerable<TEntity>
    where TEntity : class
{
    /// <summary>
    /// Properties of <typeparamref name="TEntity"/> marked with <see cref="KeyAttribute"/>.
    /// <c>Find</c> expects its key values in this order, which is the order returned by
    /// <see cref="Type.GetProperties()"/> (not guaranteed to be declaration order).
    /// NOTE(review): keys configured only by naming convention or the fluent API are not detected.
    /// </summary>
    private static readonly List<PropertyInfo> Keys = new List<PropertyInfo>();

    /// <summary>
    /// The backing store. Add and Remove modify this list directly.
    /// </summary>
    private readonly IList<TEntity> _data;

    /// <summary>
    /// <see cref="_data"/> exposed as an <see cref="IQueryable{T}"/>. It is built over the same list,
    /// so queries observe later Add/Remove calls.
    /// </summary>
    private readonly IQueryable<TEntity> _queryable;

    /// <summary>
    /// Last operation recorded for each entity. Initial data is <see cref="EntityStatus.Normal"/>;
    /// Add and Remove set <see cref="EntityStatus.Added"/> and <see cref="EntityStatus.Deleted"/>.
    /// Read through <c>GetEntityStatus</c>.
    /// </summary>
    private readonly Dictionary<TEntity, EntityStatus> _entityStatus = new Dictionary<TEntity, EntityStatus>();

    /// <summary>
    /// Static constructor. Collects the key properties of <typeparamref name="TEntity"/> once per
    /// closed generic type.
    /// </summary>
    static FakeDbSet()
    {
        var keyProperties = typeof(TEntity)
            .GetProperties()
            .Where(property => property.GetCustomAttributes(false).OfType<KeyAttribute>().Any());

        Keys.AddRange(keyProperties);
    }

    /// <summary>
    /// Constructor. Uses <paramref name="data"/> itself (not a copy) as the data store.
    /// </summary>
    /// <param name="data">The entities the set starts with; mutated by Add and Remove.</param>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
    public FakeDbSet(IList<TEntity> data)
    {
        _data = data ?? throw new ArgumentNullException(nameof(data));

        foreach (var item in data)
        {
            _entityStatus[item] = EntityStatus.Normal;
        }

        _queryable = data.AsQueryable();

        // Wrap the LINQ-to-Objects provider. FakeAsyncQueryProvider is defined elsewhere and
        // presumably lets async query operators (ToListAsync etc.) run against the in-memory list.
        Provider = new FakeAsyncQueryProvider<TEntity>(_queryable.Provider);
    }

    /// <summary>
    /// Status values recorded per entity by this fake.
    /// </summary>
    public enum EntityStatus
    {
        /// <summary>The entity is not known to this set.</summary>
        None,

        /// <summary>The entity was added through <c>Add</c>.</summary>
        Added,

        /// <summary>The entity was removed through <c>Remove</c>.</summary>
        Deleted,

        /// <summary>The entity was part of the initial data.</summary>
        Normal
    }

    /// <summary>
    /// Appends <paramref name="entity"/> to the data store and marks it as <see cref="EntityStatus.Added"/>.
    /// </summary>
    /// <param name="entity">The entity to add.</param>
    /// <returns>
    /// Always <c>null</c>. No change tracker backs this fake, so no <see cref="EntityEntry{TEntity}"/> exists;
    /// callers must not use the return value.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    public override EntityEntry<TEntity> Add(TEntity entity)
    {
        // Check before touching the list so a null entity cannot be half-added.
        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }

        _data.Add(entity);
        _entityStatus[entity] = EntityStatus.Added;
        return null;
    }

    /// <summary>
    /// Asynchronous form of <c>Add</c>; completes synchronously. Exceptions surface when awaited.
    /// </summary>
    /// <returns>A task whose result is always <c>null</c> (see <c>Add</c>).</returns>
    public override async Task<EntityEntry<TEntity>> AddAsync(TEntity entity, CancellationToken cancellationToken = new CancellationToken())
    {
        return await Task.FromResult(Add(entity));
    }

    /// <summary>
    /// Adds each entity in turn through <c>Add</c>.
    /// </summary>
    public override void AddRange(params TEntity[] entities)
    {
        AddRange((IEnumerable<TEntity>)entities);
    }

    /// <summary>
    /// Adds each entity in turn through <c>Add</c>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="entities"/> or one of its items is null.</exception>
    public override void AddRange(IEnumerable<TEntity> entities)
    {
        if (entities == null)
        {
            throw new ArgumentNullException(nameof(entities));
        }

        foreach (var entity in entities)
        {
            Add(entity);
        }
    }

    /// <summary>
    /// Asynchronous form of <c>AddRange</c>; the work is done synchronously before returning.
    /// </summary>
    public override Task AddRangeAsync(params TEntity[] entities)
    {
        AddRange(entities);
        return Task.CompletedTask;
    }

    /// <summary>
    /// Asynchronous form of <c>AddRange</c>; the work is done synchronously before returning.
    /// </summary>
    public override Task AddRangeAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = new CancellationToken())
    {
        AddRange(entities);
        return Task.CompletedTask;
    }

    /// <summary>
    /// Implements Find by comparing each key property of <typeparamref name="TEntity"/> with the
    /// corresponding entry of <paramref name="keyValues"/>.
    /// </summary>
    /// <param name="keyValues">One value per key property, in key-property order, each assignable to that property's type.</param>
    /// <returns>The single matching entity, or <c>null</c> when none matches.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="keyValues"/> is null.</exception>
    /// <exception cref="ArgumentException">The number of values differs from the number of key properties.</exception>
    /// <exception cref="InvalidOperationException">More than one entity matches.</exception>
    public override TEntity Find(params object[] keyValues)
    {
        if (keyValues == null)
        {
            throw new ArgumentNullException(nameof(keyValues));
        }

        if (keyValues.Length != Keys.Count)
        {
            throw new ArgumentException($"Must supply {Keys.Count} key values", nameof(keyValues));
        }

        var query = _queryable;
        var parameter = Expression.Parameter(typeof(TEntity), "v");

        for (var i = 0; i < Keys.Count; i++)
        {
            // Builds: v => v.<Keys[i]> == keyValues[i]
            var predicate = Expression.Lambda<Func<TEntity, bool>>(
                Expression.Equal(
                    Expression.Property(parameter, Keys[i]),
                    Expression.Constant(keyValues[i], Keys[i].PropertyType)),
                parameter);

            query = query.Where(predicate);
        }

        return query.SingleOrDefault();
    }

    /// <summary>
    /// Asynchronous form of <c>Find</c>; completes synchronously. Exceptions surface when awaited.
    /// </summary>
    public override async Task<TEntity> FindAsync(params object[] keyValues)
    {
        // Task.FromResult rather than `new Task<TEntity>(...)`: a Task created with its constructor
        // is never scheduled until Start() is called, so awaiting it would hang forever.
        return await Task.FromResult(Find(keyValues));
    }

    /// <summary>
    /// Asynchronous form of <c>Find</c>; the cancellation token is ignored.
    /// </summary>
    public override Task<TEntity> FindAsync(object[] keyValues, CancellationToken cancellationToken)
    {
        return FindAsync(keyValues);
    }

    /// <summary>
    /// Removes <paramref name="entity"/> from the data store and marks it as <see cref="EntityStatus.Deleted"/>.
    /// The status is recorded even if the entity was not in the store.
    /// </summary>
    /// <param name="entity">The entity to remove.</param>
    /// <returns>Always <c>null</c>; see <c>Add</c>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    public override EntityEntry<TEntity> Remove(TEntity entity)
    {
        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }

        _data.Remove(entity);
        _entityStatus[entity] = EntityStatus.Deleted;
        return null;
    }

    /// <summary>
    /// Returns the status recorded for <paramref name="entity"/>. Lets tests assert what
    /// Add/Remove did.
    /// </summary>
    /// <param name="entity">The entity to look up.</param>
    /// <returns>The recorded status, or <see cref="EntityStatus.None"/> if this set has never seen the entity.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    public EntityStatus GetEntityStatus(TEntity entity)
    {
        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }

        return _entityStatus.TryGetValue(entity, out var status) ? status : EntityStatus.None;
    }

    /// <summary>
    /// Enumerates the current contents of the data store.
    /// </summary>
    public IEnumerator<TEntity> GetEnumerator()
    {
        return _queryable.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return _queryable.GetEnumerator();
    }

    /// <inheritdoc />
    public Type ElementType => _queryable.ElementType;

    /// <inheritdoc />
    public Expression Expression => _queryable.Expression;

    /// <summary>
    /// Query provider wrapping the in-memory provider; assigned in the constructor.
    /// </summary>
    public IQueryProvider Provider { get; }

    /// <inheritdoc />
    IAsyncEnumerator<TEntity> IAsyncEnumerable<TEntity>.GetEnumerator()
    {
        return new FakeAsyncEnumerator<TEntity>(_queryable.GetEnumerator());
    }
}

/// <summary>
/// Builds NSubstitute fakes of <see cref="IDataContext"/> populated with test data.
/// </summary>
public class DataContextSubstitute
{
    /// <summary>
    /// Create a mock of DataContext whose <c>Users</c> set is backed by ten generated users.
    /// </summary>
    /// <returns>The substituted context.</returns>
    public IDataContext MockDataContext()
    {
        var mockContext = Substitute.For<IDataContext>();
        MockUsers(mockContext);
        return mockContext;
    }

    /// <summary>
    /// Generates ten confirmed, active, non-banned users and configures
    /// <paramref name="mockContext"/>.Users to return them through a FakeDbSet.
    /// All users share one salt and the same encrypted password "secret". The first user is named "admin".
    /// </summary>
    /// <param name="mockContext">The substitute whose <c>Users</c> property is configured.</param>
    /// <returns>The generated users: the same list instance that backs the fake set.</returns>
    private static IList<User> MockUsers(IDataContext mockContext)
    {
        // Computed once, so every user gets the same salt and encrypted password.
        var salt = TokenUtility.GenerateToken();
        var password = PasswordEncrypter.Encrypt("secret", salt);

        var users = Builder<User>.CreateListOfSize(10).All()
            .With(u => u.EmailConfirmed = true)
            .With(u => u.Active = true)
            .With(u => u.Banned = false)
            .With(u => u.Salt = salt)
            .With(u => u.Password = password)
            .With(u => u.Token = TokenUtility.GenerateToken())
            .TheFirst(1)
            .With(u => u.UserName = "admin")
            .Build();

        mockContext.Users.Returns(new FakeDbSet<User>(users));
        return users;
    }
}