Tags: linq-to-sql, expression-trees

C# Traverse expression tree to extract table names


I need a way to traverse a LINQ-to-SQL expression tree to extract the table name(s) in the query. Even just the first table used in the query would probably be sufficient.

Example:

var query = from c in Db.Customers select c;

And the ideal function:

string TableName = ExtractTablesFromQuery(query);

would return the string "Customers".


Solution

  • LINQ to SQL doesn't expose this functionality directly, so you have two options:

    1. Use the dataContext.GetCommand(myQuery) method and parse the generated T-SQL.

    This can get tricky with joins and similar constructs, but it guarantees that you get the exact table names that will be involved.
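    A minimal sketch of the first approach, assuming the SQL text comes from dataContext.GetCommand(query).CommandText and that table references take LINQ to SQL's usual bracketed form after FROM or JOIN (for example FROM [dbo].[Customers] AS [t0]). The SqlTableNames class and its regex are hypothetical and illustrative only; subqueries, CTEs, or unusual quoting would need a real T-SQL parser.

    ```csharp
    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Text.RegularExpressions;

    // Hypothetical helper: pulls table names out of generated T-SQL with a regex.
    // Assumption: a table name follows FROM or JOIN, optionally schema-qualified
    // and wrapped in square brackets.
    public static class SqlTableNames
    {
        private static readonly Regex TableRef = new Regex(
            @"\b(?:FROM|JOIN)\s+(?:\[?\w+\]?\.)?\[?(?<table>\w+)\]?",
            RegexOptions.IgnoreCase);

        public static List<string> Extract(string sql)
        {
            return TableRef.Matches(sql)
                .Cast<Match>()
                .Select(m => m.Groups["table"].Value)     // drop schema and brackets
                .Distinct(StringComparer.OrdinalIgnoreCase) // each table once
                .ToList();
        }
    }
    ```

    With LINQ to SQL you would call it as SqlTableNames.Extract(Db.GetCommand(query).CommandText). A regex like this is exactly where the trickiness with joins and nested queries shows up, so treat its output as a best effort.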

    2. Visit the expression tree yourself.

    This isn't too difficult, but it has a drawback: LINQ to SQL infers and optimizes which tables it actually uses, so you won't get a 100% accurate picture of what will be executed. For example, if you joined a table but didn't return any of its results, it could be optimized out, and you wouldn't know this from visiting the expression tree unless you replicated LINQ to SQL's optimizations exactly (which would be a lot of work).

    If you want to try option 2 anyway, here's an example to get you started:

    using System;
    using System.Collections.Generic;
    using System.Data.Linq;
    using System.Data.Linq.Mapping;
    using System.Linq;
    using System.Linq.Expressions;

    public static class TableFinder
    {
        // Returns the mapped table names referenced by a LINQ to SQL query.
        public static IEnumerable<string> GetTableNames(this DataContext context, IQueryable queryable) {
            var visitor = new TableFindingVisitor(context.Mapping);
            visitor.Visit(queryable.Expression);
            return visitor.Tables.Select(t => t.TableName).Distinct();
        }

        class TableFindingVisitor : ExpressionVisitor
        {
            private readonly HashSet<MetaTable> foundTables = new HashSet<MetaTable>();
            private readonly MetaModel mapping;

            public TableFindingVisitor(MetaModel mapping) {
                this.mapping = mapping;
            }

            public IEnumerable<MetaTable> Tables { get { return foundTables; } }

            // Query roots such as Db.Customers appear as constants of type Table<T>.
            protected override Expression VisitConstant(ConstantExpression node) {
                // Guard: GetGenericTypeDefinition throws for non-generic types (e.g. int constants).
                if (node.Type.IsGenericType && node.Type.GetGenericTypeDefinition() == typeof(Table<>))
                    CheckType(node.Type.GetGenericArguments()[0]);
                return base.VisitConstant(node);
            }

            // Member accesses such as c.Name reveal the entity type they belong to.
            protected override Expression VisitMember(MemberExpression node) {
                CheckType(node.Member.DeclaringType);
                return base.VisitMember(node);
            }

            private void CheckType(Type t) {
                if (t == null)
                    return;
                // Expected to return null for types that aren't mapped to a table.
                var table = mapping.GetTable(t);
                if (table != null)
                    foundTables.Add(table); // HashSet ignores duplicates
            }
        }
    }
    

    To use this, foreach over the results of dataContext.GetTableNames(myQuery).
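    The same traversal idea can be tried without a database. This sketch uses in-memory sources from AsQueryable(): the queryable it returns exposes a ConstantExpression wrapping itself as its Expression, much as Table<T> does in LINQ to SQL, so query roots show up as constants. The Customer, Order, and QuerySourceFinder names are hypothetical. It only collects the element type of every queryable constant; with a real DataContext you would map each type through context.Mapping.GetTable(type).TableName, as in the example above.

    ```csharp
    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Linq.Expressions;

    // Hypothetical entity types standing in for LINQ to SQL mapped classes.
    public class Customer { public int Id { get; set; } public string Name { get; set; } }
    public class Order { public int Id { get; set; } public int CustomerId { get; set; } }

    public static class QuerySourceFinder
    {
        // Returns the element type of each IQueryable that appears as a constant
        // (i.e. a query root) anywhere in the query's expression tree.
        public static List<Type> FindSourceTypes(IQueryable query)
        {
            var visitor = new SourceVisitor();
            visitor.Visit(query.Expression);
            return visitor.Sources;
        }

        private sealed class SourceVisitor : ExpressionVisitor
        {
            public List<Type> Sources { get; } = new List<Type>();

            protected override Expression VisitConstant(ConstantExpression node)
            {
                // Query roots are constants whose value is itself an IQueryable;
                // ordinary constants (numbers, strings) are ignored.
                if (node.Value is IQueryable source && !Sources.Contains(source.ElementType))
                    Sources.Add(source.ElementType);
                return base.VisitConstant(node);
            }
        }
    }
    ```

    Note that this reports every source the tree mentions. As described above, it cannot tell you which of those the provider will actually keep in the SQL it generates.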