
I would like to pass a value to the constructor of a DbContext and then have that value enforce "filtering" on the related DbSets. Is this possible, or is there a better approach?

Code might look like this:

using System.Data.Entity;
using System.Linq;

public class Contact {
  public int ContactId { get; set; }
  public int CompanyId { get; set; }
  public string Name { get; set; }
}

public class ContactContext : DbContext {
  private readonly int _companyId;

  public ContactContext(int companyId) {
    // How can _companyId be applied to every query and every save?
    _companyId = companyId;
  }

  public DbSet<Contact> Contacts { get; set; }
}

using (var cc = new ContactContext(123)) {
  // Would only return contacts where CompanyId = 123
  var all = (from i in cc.Contacts select i);

  // Would automatically set the CompanyId to 123
  var contact = new Contact { Name = "Doug" };
  cc.Contacts.Add(contact);
  cc.SaveChanges();

  // Would throw custom exception
  contact.CompanyId = 456;
  cc.SaveChanges();
}
Tsahi Asher
Doug Clutter

2 Answers


I decided to implement a custom IDbSet to deal with this. To use this class, you pass in a DbContext, a filter expression, and (optionally) an Action to initialize new entities so they meet the filter criteria.

I've tested enumerating the set and using the Count aggregate function. In both cases the filter ends up in the generated SQL, so this should be much more efficient than filtering on the client.
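Why does Count() pick up the filter? Queryable operators such as Count() don't call GetEnumerator; they build a new expression on top of the source's Expression and hand it to the source's Provider. Because FilteredDbSet returns FilteredSet.Expression and FilteredSet.Provider, anything composed on it includes the Where, which is also why overriding only GetEnumerator (as tried in the comments on the other answer) affects enumeration but not Count(). A minimal sketch of that delegation, assuming an in-memory source via AsQueryable() instead of EF; the class name FilteredQueryable is hypothetical.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical wrapper: exposes the Expression and Provider of source.Where(filter),
// so every Queryable operator composed on it (Count, Any, Select, ...) includes the filter.
public class FilteredQueryable<T> : IQueryable<T>
{
    private readonly IQueryable<T> _filtered;

    public FilteredQueryable(IQueryable<T> source, Expression<Func<T, bool>> filter)
    {
        _filtered = source.Where(filter);
    }

    public Type ElementType { get { return typeof(T); } }

    // Operators like Queryable.Count build on this expression, so the Where is always part of it.
    public Expression Expression { get { return _filtered.Expression; } }

    // The provider executes (or, with EF, translates to SQL) the composed expression.
    public IQueryProvider Provider { get { return _filtered.Provider; } }

    public IEnumerator<T> GetEnumerator() { return _filtered.GetEnumerator(); }

    IEnumerator IEnumerable.GetEnumerator() { return GetEnumerator(); }
}

public class Contact
{
    public int ContactId { get; set; }
    public int CompanyId { get; set; }
    public string Name { get; set; }
}
```

With EF, the source would be context.Set<Contact>() and its provider would translate the composed Where(...).Count() into SQL rather than evaluating it in memory.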

using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.ComponentModel;
using System.Data.Entity;
using System.Linq;
using System.Linq.Expressions;


namespace MakeMyPledge.Data
{
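    /// <summary>
    /// Wraps a DbSet so that queries are restricted by a filter expression and entities
    /// that do not match the filter cannot be added, attached, found, or removed through it.
    /// </summary>
    class FilteredDbSet<TEntity> : IDbSet<TEntity>, IOrderedQueryable<TEntity>, IOrderedQueryable, IQueryable<TEntity>, IQueryable, IEnumerable<TEntity>, IEnumerable, IListSource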
        where TEntity : class
    {
        private readonly DbSet<TEntity> Set;
        private readonly IQueryable<TEntity> FilteredSet;
        private readonly Action<TEntity> InitializeEntity;

        public FilteredDbSet(DbContext context)
            : this(context.Set<TEntity>(), i => true, null)
        {
        }

        public FilteredDbSet(DbContext context, Expression<Func<TEntity, bool>> filter)
            : this(context.Set<TEntity>(), filter, null)
        {
        }

        public FilteredDbSet(DbContext context, Expression<Func<TEntity, bool>> filter, Action<TEntity> initializeEntity)
            : this(context.Set<TEntity>(), filter, initializeEntity)
        {
        }

        private FilteredDbSet(DbSet<TEntity> set, Expression<Func<TEntity, bool>> filter, Action<TEntity> initializeEntity)
        {
            Set = set;
            FilteredSet = set.Where(filter);
            MatchesFilter = filter.Compile();
            InitializeEntity = initializeEntity;
        }

        public Func<TEntity, bool> MatchesFilter { get; private set; }

        public void ThrowIfEntityDoesNotMatchFilter(TEntity entity)
        {
            if (!MatchesFilter(entity))
                throw new ArgumentOutOfRangeException("entity", "The entity does not match the filter of this set.");
        }

        public TEntity Add(TEntity entity)
        {
            DoInitializeEntity(entity);
            ThrowIfEntityDoesNotMatchFilter(entity);
            return Set.Add(entity);
        }

        public TEntity Attach(TEntity entity)
        {
            ThrowIfEntityDoesNotMatchFilter(entity);
            return Set.Attach(entity);
        }

        public TDerivedEntity Create<TDerivedEntity>() where TDerivedEntity : class, TEntity
        {
            var entity = Set.Create<TDerivedEntity>();
            DoInitializeEntity(entity);
            return (TDerivedEntity)entity;
        }

        public TEntity Create()
        {
            var entity = Set.Create();
            DoInitializeEntity(entity);
            return entity;
        }

        public TEntity Find(params object[] keyValues)
        {
            var entity = Set.Find(keyValues);
            if (entity == null)
                return null;

            // If the user queried an item outside the filter, then we throw an error.
            // If IDbSet had a Detach method we would use it...sadly, we have to be ok with the item being in the Set.
            ThrowIfEntityDoesNotMatchFilter(entity);
            return entity;
        }

        public TEntity Remove(TEntity entity)
        {
            ThrowIfEntityDoesNotMatchFilter(entity);
            return Set.Remove(entity);
        }

        /// <summary>
        /// Returns the items in the local cache
        /// </summary>
        /// <remarks>
        /// It is possible to add/remove entities via this property that do NOT match the filter.
        /// Use the <see cref="ThrowIfEntityDoesNotMatchFilter"/> method before adding/removing an item from this collection.
        /// </remarks>
        public ObservableCollection<TEntity> Local { get { return Set.Local; } }

        IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator() { return FilteredSet.GetEnumerator(); }

        IEnumerator IEnumerable.GetEnumerator() { return FilteredSet.GetEnumerator(); }

        Type IQueryable.ElementType { get { return typeof(TEntity); } }

        Expression IQueryable.Expression { get { return FilteredSet.Expression; } }

        IQueryProvider IQueryable.Provider { get { return FilteredSet.Provider; } }

        bool IListSource.ContainsListCollection { get { return false; } }

        IList IListSource.GetList() { throw new InvalidOperationException(); }

        void DoInitializeEntity(TEntity entity)
        {
            if (InitializeEntity != null)
                InitializeEntity(entity);
        }
    }
}
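FilteredDbSet uses one filter expression two ways: set.Where(filter) for queries, and filter.Compile() (MatchesFilter) for in-memory checks in Add, Attach, Find, and Remove. The sketch below isolates the Add/Attach checks over a plain List<T> so it runs without EF; ScopedCollection is a hypothetical name and the exception message is illustrative, so treat it as an approximation of the pattern rather than a drop-in replacement.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

public class Contact
{
    public int ContactId { get; set; }
    public int CompanyId { get; set; }
    public string Name { get; set; }
}

// Hypothetical in-memory stand-in for FilteredDbSet's Add/Attach checks (no EF involved).
public class ScopedCollection<T> where T : class
{
    private readonly List<T> _items = new List<T>();
    private readonly Func<T, bool> _matchesFilter;
    private readonly Action<T> _initializeEntity;

    public ScopedCollection(Expression<Func<T, bool>> filter, Action<T> initializeEntity)
    {
        // Same idea as MatchesFilter: compile the query filter once for in-memory checks.
        _matchesFilter = filter.Compile();
        _initializeEntity = initializeEntity;
    }

    public IReadOnlyList<T> Items { get { return _items; } }

    // Mirrors Add: initialize the new entity first, then verify it matches the filter.
    public T Add(T entity)
    {
        if (_initializeEntity != null)
            _initializeEntity(entity);
        return Attach(entity);
    }

    // Mirrors Attach: no initialization, so an entity outside the filter is rejected.
    public T Attach(T entity)
    {
        if (!_matchesFilter(entity))
            throw new ArgumentOutOfRangeException("entity", "The entity does not match the filter.");
        _items.Add(entity);
        return entity;
    }
}
```

Because Add runs the initializer before the check, an initializer such as c => c.CompanyId = 123 makes Add succeed for new contacts; the check mostly matters for Attach (and for Find and Remove in the real class), where entities arrive with their own CompanyId.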
Doug Clutter
  • This is great! Is there any way to get this to filter lazy-loaded items? – maxfridbe Mar 08 '12 at 19:25
  • This example works great, but I found a case that causes me problems: Includes are not applied on such a set. I tried to create an extension method for this (`public static IQueryable Include(this IDbSet dbSet, Expression> expression)`) where I check the type of dbSet. If it's my type, I call the `Include` method on the original `DbSet`; if it's not, I call it on `IQueryable`. The problem is that `Include` has to be called on the `IDbSet`, not on something else (e.g. the result of `AsNoTracking`)... Any ideas? – ghigad Jan 26 '15 at 18:45
  • This works! However, the `FilteredSet = set.Where(filter)` behavior is different from the original DbSet; consider using a property? – Yiping Mar 17 '17 at 13:12
  • This answer was very helpful for me, thanks. There's just one issue: you'll find that Includes don't work anymore (`ctx.WrappedDbSet.Include(x => x.someProperty)`). But it's easy to add support: simply add a method `public IDbSet Include(string path) {}` and call Include on FilteredSet. –  Dec 14 '18 at 16:53

EF doesn't have a built-in "filter" feature. You can try to achieve something like that by inheriting from DbSet, but I think it will still be problematic. For example, DbSet directly implements IQueryable, so there is probably no way to inject a custom condition.

This will require a wrapper (it can be a repository) that handles these requirements:

  • The condition in queries can be handled by a wrapper method around DbSet that adds a Where condition.
  • Inserts can be handled by a wrapper method as well.
  • Updates must be handled by overriding SaveChanges and using context.ChangeTracker to get all modified entities. Then you can check how those entities were modified (for example, whether CompanyId changed).

By wrapper I do not mean a custom DbSet implementation (that is too complex), but something like this:

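using System.Data.Entity;
using System.Linq;

public class MyDal
{
    private readonly DbSet<MyEntity> _set;

    public MyDal(DbContext context)
    {
        _set = context.Set<MyEntity>();
    }

    public IQueryable<MyEntity> GetQuery()
    {
        // Every query built on this starts from the filtered set; put your condition here
        return _set.Where(e => ...);
    }

    // Attach, Insert, Delete wrappers go here
}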
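MyDal covers the query side; the update bullet needs change detection. In EF that means overriding SaveChanges and inspecting ChangeTracker.Entries<Contact>(), whose entries expose original and current values. Because that requires a real context, here is a minimal in-memory analogue of the same check, using the question's Contact and its CompanyId = 456 scenario. CompanyIdChangeGuard, Track, FindModified, and ValidateBeforeSave are hypothetical names, and InvalidOperationException stands in for the question's custom exception.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public class Contact
{
    public int ContactId { get; set; }
    public int CompanyId { get; set; }
    public string Name { get; set; }
}

// Hypothetical in-memory analogue of an overridden SaveChanges that inspects tracked entries:
// it remembers each tracked contact's original CompanyId and refuses to "save" if one changed.
public class CompanyIdChangeGuard
{
    // Keys use reference equality, since Contact does not override Equals/GetHashCode.
    private readonly Dictionary<Contact, int> _originalCompanyIds = new Dictionary<Contact, int>();

    public void Track(Contact contact)
    {
        // Snapshot the value at tracking time (EF keeps this as the entry's original value).
        _originalCompanyIds[contact] = contact.CompanyId;
    }

    // Contacts whose current CompanyId differs from the snapshot.
    public List<Contact> FindModified()
    {
        return _originalCompanyIds
            .Where(pair => pair.Key.CompanyId != pair.Value)
            .Select(pair => pair.Key)
            .ToList();
    }

    // Call this where SaveChanges would run; throws instead of persisting a changed CompanyId.
    public void ValidateBeforeSave()
    {
        var modified = FindModified();
        if (modified.Count > 0)
            throw new InvalidOperationException(
                modified.Count + " tracked contact(s) had CompanyId changed.");
    }
}
```

In a real context the snapshot is unnecessary, because the change tracker already holds the original values; the guard only illustrates the comparison the overridden SaveChanges would make.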
Ladislav Mrnka
  • Along with what Ladislav said, check out the Specification Pattern (http://devlicio.us/blogs/jeff_perrin/archive/2006/12/13/the-specification-pattern.aspx). Without getting into a ton of detail, it may help with consolidating your filters. – DDiVita Apr 16 '11 at 00:43
  • @DDiVita Not sure that post helps me much. I want to inject an additional filter into the SQL WHERE clause that EF generates. – Doug Clutter Apr 20 '11 at 16:42
  • @Ladislav - I have tried to create a new IDbSet<> class that basically wraps the DbSet<> provided by DbContext. I overrode the GetEnumerator as follows: ` public IEnumerator GetEnumerator() { return (from i in mDbSet where i.AccountNumber.Value == mAccountNumber.Value select i).GetEnumerator(); }` This worked for queries that enumerate the set...but not for queries that aggregate like MySet.Count(). Any thoughts? – Doug Clutter Apr 20 '11 at 16:44
  • @Doug: You made it too complex. I will add an example to my answer. – Ladislav Mrnka Apr 20 '11 at 19:15
  • @Ladislav - I agree; there is really no reason to do a full IDbSet implementation. I've found overriding the Expression property to be exceptionally difficult. – Doug Clutter Apr 21 '11 at 11:01
  • @Ladislav - While a full IDbSet implementation may be overkill, I've implemented a generic version that can be used anywhere. Thanks for helping put me on the right track. – Doug Clutter Apr 21 '11 at 14:51