June 2009 - Posts

T4 for View Generation

The EF team blogged about it here, and I think it’s very useful.

First, I think it’s easier to use than EdmGen.

Moreover, unlike EdmGen, it lets you use embedded metadata artifacts.

If you want to use the MS T4 template with VS 2010, don’t forget to change the XML namespaces to these:

// Namespaces to use with VS 2010 EDMX files (they replace the ones in the original MS template).
XNamespace edmxns = "http://schemas.microsoft.com/ado/2008/10/edmx";       // EDMX wrapper (edmx:Edmx, edmx:Runtime, ...)
XNamespace csdlns = "http://schemas.microsoft.com/ado/2008/09/edm";        // CSDL: conceptual model Schema
XNamespace mslns = "http://schemas.microsoft.com/ado/2008/09/mapping/cs";  // MSL: conceptual-to-storage Mapping
XNamespace ssdlns = "http://schemas.microsoft.com/ado/2009/02/edm/ssdl";   // SSDL: storage model Schema

So you should have the following:

<#@ template language="C#" hostspecific="true"#>
<#@ output extension=".cs" #>

<#@ assembly name="System.Core" #>
<#@ assembly name="System.Data" #>
<#@ assembly name="System.Data.Entity" #>
<#@ assembly name="System.Data.Entity.Design" #>
<#@ assembly name="System.Xml" #>
<#@ assembly name="System.Xml.Linq" #>

<#@ import namespace="System" #>
<#@ import namespace="System.IO" #>
<#@ import namespace="System.Diagnostics" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ import namespace="System.Data.Entity.Design" #>
<#@ import namespace="System.Data.Metadata.Edm" #>
<#@ import namespace="System.Data.Mapping" #>
<#@ import namespace="System.Xml" #>
<#@ import namespace="System.Xml.Linq" #>

<#
    // Find the EDMX file to process: Model1.Views.tt generates views for Model1.edmx.
    // Only a trailing ".Views" suffix is stripped (case-insensitively). The rest of the name keeps its
    // original casing, so "Shop.Views.Admin.Views.tt" resolves to "Shop.Views.Admin.edmx" and the
    // error message below shows the real path instead of a lowercased one.
    const string viewsSuffix = ".Views";
    string templateName = Path.GetFileNameWithoutExtension(this.Host.TemplateFile);
    if (templateName.EndsWith(viewsSuffix, StringComparison.OrdinalIgnoreCase))
    {
        templateName = templateName.Substring(0, templateName.Length - viewsSuffix.Length);
    }
    string edmxFileName = templateName + ".edmx";
    string edmxFilePath = Path.Combine(Path.GetDirectoryName(this.Host.TemplateFile), edmxFileName);
    if (File.Exists(edmxFilePath))
    {
        // Call helper class to generate pre-compiled views and write to output
        this.WriteLine(GenerateViews(edmxFilePath));
    }
    else
    {
        this.Error(String.Format("No views were generated. Cannot find file {0}. Ensure the project has an EDMX file and the file name of the .tt file is of the form [edmx-file-name].Views.tt", edmxFilePath));
    }
    // All done!
#>

<#+
    private String GenerateViews(string edmxFilePath)
    {
        String generatedViews = String.Empty;
        try
        {
            using (StreamWriter writer = new StreamWriter(new MemoryStream()))
            {
                XmlReader csdlReader = null;
                XmlReader mslReader = null;
                XmlReader ssdlReader = null;

                // Crack open the EDMX file and get readers over the CSDL, MSL and SSDL portions
                GetConceptualMappingAndStorageReaders(edmxFilePath, out csdlReader, out mslReader, out ssdlReader);

                // Initialize item collections
                EdmItemCollection edmItems = new EdmItemCollection(new XmlReader[] { csdlReader });
                StoreItemCollection storeItems = new StoreItemCollection(new XmlReader[] { ssdlReader });
                StorageMappingItemCollection mappingItems = new StorageMappingItemCollection(edmItems, storeItems, new XmlReader[] { mslReader });

                // Initialize the view generator to generate views in C#
                EntityViewGenerator viewGenerator = new EntityViewGenerator();
                viewGenerator.LanguageOption = LanguageOption.GenerateCSharpCode;
                IList<EdmSchemaError> errors = viewGenerator.GenerateViews(mappingItems, writer);

                foreach (EdmSchemaError e in errors)
                {
                    // log error
                    this.Error(e.Message);
                }

                MemoryStream memStream = writer.BaseStream as MemoryStream;
                generatedViews = Encoding.UTF8.GetString(memStream.ToArray());
            }
        }
        catch (Exception ex)
        {
            // log error
            this.Error(ex.ToString());
        }

        return generatedViews;
    }

    private void GetConceptualMappingAndStorageReaders(string edmxFile, out XmlReader csdlReader, out XmlReader mslReader, out XmlReader ssdlReader)
    {
        csdlReader = null;
        mslReader = null;
        ssdlReader = null;

        XNamespace edmxns = "http://schemas.microsoft.com/ado/2008/10/edmx";
        XNamespace csdlns = "http://schemas.microsoft.com/ado/2008/09/edm";
        XNamespace mslns = "http://schemas.microsoft.com/ado/2008/09/mapping/cs";
        XNamespace ssdlns = "http://schemas.microsoft.com/ado/2009/02/edm/ssdl";

        XDocument edmxDoc = XDocument.Load(edmxFile);
        if (edmxDoc != null)
        {
            XElement edmxNode = edmxDoc.Element(edmxns + "Edmx");
            if (edmxNode != null)
            {
                XElement runtimeNode = edmxNode.Element(edmxns + "Runtime");
                if (runtimeNode != null)
                {
                    // Create XmlReader over CSDL in EDMX
                    XElement conceptualModelsNode = runtimeNode.Element(edmxns + "ConceptualModels");
                    if (conceptualModelsNode != null)
                    {
                        XElement csdlContent = conceptualModelsNode.Element(csdlns + "Schema");
                        if (csdlContent != null)
                        {
                            csdlReader = csdlContent.CreateReader();
                        }
                    }

                    // Create XmlReader over MSL in EDMX
                    XElement mappingsNode = runtimeNode.Element(edmxns + "Mappings");
                    if (mappingsNode != null)
                    {
                        XElement mslContent = mappingsNode.Element(mslns + "Mapping");
                        if (mslContent != null)
                        {
                            mslReader = mslContent.CreateReader();
                        }
                    }

                    // Create XmlReader over SSDL in EDMX
                    XElement storageModelsNode = runtimeNode.Element(edmxns + "StorageModels");
                    if (storageModelsNode != null)
                    {
                        XElement ssdlContent = storageModelsNode.Element(ssdlns + "Schema");
                        if (ssdlContent != null)
                        {
                            ssdlReader = ssdlContent.CreateReader();
                        }
                    }
                }
            }
        }
    }
#>

Posted by Matthieu MEZIL | with no comments

POCO T4

With the EF4 Feature CTP1, we have a great POCO template. It is split into two .tt files so that the POCO entity classes and the context can live in different projects.

It’s very very cool.

However, I think it’s a shame not to have an interface for the context, particularly for mocking it.

So I changed the .Context.tt file like this:

 

<#@ template language="C#" debug="false" hostspecific="true"#>
<#@ include file="EF.Utility.ctp.CS.ttinclude"#><#@
output extension=".cs"#><#
// Copyright (c) Microsoft Corporation.  All rights reserved.
//
// Modified POCO .Context.tt: besides the ObjectContext class, this template emits an interface named
// "I" + <container name> into its own .cs file, and the context class implements that interface,
// so the context can be mocked.
// NOTE(review): both emitted files contain a hard-coded "using Entities;", presumably the namespace
// of the POCO entity classes. Adjust it to your project.

CodeGenerationTools code = new CodeGenerationTools(this);
MetadataLoader loader = new MetadataLoader(this);
CodeRegion region = new CodeRegion(this, 1);
MetadataTools ef = new MetadataTools(this);

// NOTE(review): the model file name is hard-coded; change it to match your .edmx file.
string inputFile = @"Northwind.edmx";
EdmItemCollection ItemCollection = loader.CreateEdmItemCollection(inputFile);
string namespaceName = code.VsNamespaceSuggestion();

TemplateFileManager fileManager = TemplateFileManager.Create(this);
EntityContainer container = ItemCollection.GetItems<EntityContainer>().FirstOrDefault();
if (container == null)
{
    return "// No EntityContainer exists in the model, so no code was generated";
}

// Emit the context interface ("I" + escaped container name) into a separate file. Presumably the text
// produced between StartNewFile and fileManager.WriteFiles() below goes to that file (confirm in
// TemplateFileManager, EF.Utility.ctp.CS.ttinclude).
    string interfaceName = "I" + code.Escape(container);
    fileManager.StartNewFile(interfaceName + ".cs");
#>
using System;
using System.Collections.Generic;
using System.Data.Objects;
using System.Data.EntityClient;
using Entities;

<#
// Open the namespace (if any) and indent the interface body accordingly.
if (!String.IsNullOrEmpty(namespaceName))
{
#>
namespace <#=code.EscapeNamespace(namespaceName)#>
{
<#
    PushIndent(CodeRegion.GetIndent(1));
}
#>
<#=Accessibility.ForType(container)#> interface <#=interfaceName#>
{
<#
        // One read-only IObjectSet<T> property per entity set; the context class implements these explicitly.
        foreach (EntitySet entitySet in container.BaseEntitySets.OfType<EntitySet>())
        {
#>
    IObjectSet<<#=code.Escape(entitySet.ElementType)#>> <#=code.Escape(entitySet)#> { get; }
<#
        }
#>
<#
        // Blank separator line before the function-import members, only when there are any.
        if (container.FunctionImports.Any())
        {
#>

<#
        }
        // One IEnumerable<T> method per function import; the context implements it by forwarding to its
        // ObjectResult<T>-returning method.
        foreach (EdmFunction edmFunction in container.FunctionImports)
        {
            var parameters = FunctionImportParameter.Create(edmFunction.Parameters, code, ef);
            string paramList = String.Join(", ", parameters.Select(p => p.FunctionParameterType + " " + p.FunctionParameterName).ToArray());
            // Function imports without a return parameter get no member (the context class skips them too).
            if(edmFunction.ReturnParameter == null)
            {
                continue;
            }
            string returnTypeElement = code.Escape(ef.GetElementType(edmFunction.ReturnParameter.TypeUsage));

#>
    IEnumerable<<#=returnTypeElement#>> <#=code.Escape(edmFunction)#>(<#=paramList#>);
<#
        }
#>
}
<#
if (!String.IsNullOrEmpty(namespaceName))
{
    PopIndent();
#>
}
<#
}
    // Write out the interface file started above; the rest of the template produces the context class.
    fileManager.WriteFiles();
#>
//------------------------------------------------------------------------------
// <auto-generated>
//     This code was generated from a template.
//
//     Changes to this file may cause incorrect behavior and will be lost if
//     the code is regenerated.
// </auto-generated>
//------------------------------------------------------------------------------

using System;
using System.Collections.Generic;
using System.Data.Objects;
using System.Data.EntityClient;
using Entities;

<#
// Open the namespace (if any) and indent the class body accordingly.
if (!String.IsNullOrEmpty(namespaceName))
{
#>
namespace <#=code.EscapeNamespace(namespaceName)#>
{
<#
    PushIndent(CodeRegion.GetIndent(1));
}
#>
<#=Accessibility.ForType(container)#> partial class <#=code.Escape(container)#> : ObjectContext, <#= interfaceName #>
{
    public const string ConnectionString = "name=<#=container.Name#>";
    public const string ContainerName = "<#=container.Name#>";

    #region Constructors

    public <#=code.Escape(container)#>()
        : base(ConnectionString, ContainerName)
    {
        ContextOptions.DeferredLoadingEnabled = true;
    }

    public <#=code.Escape(container)#>(string connectionString)
        : base(connectionString, ContainerName)
    {
        ContextOptions.DeferredLoadingEnabled = true;
    }

    public <#=code.Escape(container)#>(EntityConnection connection)
        : base(connection, ContainerName)
    {
        ContextOptions.DeferredLoadingEnabled = true;
    }

    #endregion

<#
        // Entity sets: an ObjectSet<T> property created on first access and cached in a private field,
        // plus an explicit implementation of the interface property returning it as IObjectSet<T>.
        region.Begin("ObjectSet Properties");

        foreach (EntitySet entitySet in container.BaseEntitySets.OfType<EntitySet>())
        {
#>

    <#=AccessibilityAndVirtual(Accessibility.ForReadOnlyProperty(entitySet))#> ObjectSet<<#=code.Escape(entitySet.ElementType)#>> <#=code.Escape(entitySet)#>
    {
        get { return <#=code.FieldName(entitySet) #>  ?? (<#=code.FieldName(entitySet)#> = CreateObjectSet<<#=code.Escape(entitySet.ElementType)#>>("<#=entitySet.Name#>")); }
    }
    private ObjectSet<<#=code.Escape(entitySet.ElementType)#>> <#=code.FieldName(entitySet)#>;
    IObjectSet<<#=code.Escape(entitySet.ElementType)#>> <#=interfaceName#>.<#=code.Escape(entitySet)#>
    {
        get { return <#=code.Escape(entitySet)#>; }
    }
<#
        }

        region.End();
#>

<#
        // Function imports: an ObjectResult<T> method calling ExecuteFunction, plus an explicit interface
        // implementation exposing it as IEnumerable<T>. Imports without a return parameter are skipped.
        region.Begin("Function Imports");

        foreach (EdmFunction edmFunction in container.FunctionImports)
        {
            var parameters = FunctionImportParameter.Create(edmFunction.Parameters, code, ef);
            string paramList = String.Join(", ", parameters.Select(p => p.FunctionParameterType + " " + p.FunctionParameterName).ToArray());
            if(edmFunction.ReturnParameter == null)
            {
                continue;
            }
            string returnTypeElement = code.Escape(ef.GetElementType(edmFunction.ReturnParameter.TypeUsage));

#>
    <#=AccessibilityAndVirtual(Accessibility.ForMethod(edmFunction))#> ObjectResult<<#=returnTypeElement#>> <#=code.Escape(edmFunction)#>(<#=paramList#>)
    {
<#
            // Parameters flagged NeedsLocalVariable get an ObjectParameter local: built from the value when it
            // is present (HasValue / != null), otherwise typed by its CLR type with no value.
            foreach (var parameter in parameters)
            {
                if (!parameter.NeedsLocalVariable)
                {
                    continue;
                }
#>

        ObjectParameter <#=parameter.LocalVariableName#>;

        if (<#=parameter.IsNullableOfT ? parameter.FunctionParameterName + ".HasValue" : parameter.FunctionParameterName + " != null"#>)
        {
            <#=parameter.LocalVariableName#> = new ObjectParameter("<#=parameter.EsqlParameterName#>", <#=parameter.FunctionParameterName#>);
        }
        else
        {
            <#=parameter.LocalVariableName#> = new ObjectParameter("<#=parameter.EsqlParameterName#>", typeof(<#=parameter.RawClrTypeName#>));
        }
<#
            }
#>
        return base.ExecuteFunction<<#=returnTypeElement#>>("<#=edmFunction.Name#>"<#=code.StringBefore(", ", string.Join(", ", parameters.Select(p => p.ExecuteParameterName).ToArray()))#>);
    }
    IEnumerable<<#=returnTypeElement#>> <#=interfaceName#>.<#=code.Escape(edmFunction)#>(<#=paramList#>)
    {
        return <#=code.Escape(edmFunction)#>(<#=String.Join(", ", parameters.Select(p => p.FunctionParameterName).ToArray())#>);
    }
<#
        }

        region.End();
#>
}
<#
if (!String.IsNullOrEmpty(namespaceName))
{
    PopIndent();
#>
}
<#
}
#>
<#+
/// <summary>
/// Appends the "virtual" modifier to an accessibility keyword, except for "private",
/// which is returned as-is because private members cannot be virtual.
/// </summary>
string AccessibilityAndVirtual(string accessibility)
{
    return accessibility == "private"
        ? accessibility
        : accessibility + " virtual";
}
#>

 

Enjoy :-)

Posted by Matthieu MEZIL | with no comments

SubObjectSet

With EF, when you use the TPH or TPC inheritance mapping scenarios, the EntitySet is defined on the base class.

As I have often mentioned in the past with EF v1, you can add a property to your context that returns EntitySet.OfType<MySubType>().

OK, that’s interesting, but… In EF v1, the EntitySet is an ObjectQuery<T> property, and so is our property. In EF v2, however, the EntitySet is an ObjectSet<T>. This class implements the IObjectSet<T> interface, which has three methods to add, attach and delete entities.

Someone told me that he wanted to be able to use these methods directly on the “sub EntitySet” property.

To do this, I wrote the following class:

/// <summary>
/// Exposes the <typeparamref name="TInherited"/> entities stored in an entity set of base type
/// <typeparamref name="TBase"/> (TPH/TPC inheritance) as a query that also supports the
/// <see cref="IObjectSet{TEntity}"/> add/attach/delete operations. Those operations are forwarded to the base set.
/// </summary>
/// <typeparam name="TBase">Entity type the underlying entity set is declared on.</typeparam>
/// <typeparam name="TInherited">Derived entity type exposed by this set.</typeparam>
public class SubObjectSet<TBase, TInherited> : ObjectQuery<TInherited>, IObjectSet<TInherited>
    where TBase : class
    where TInherited : class, TBase
{
    /// <summary>The base-type object set that write operations are forwarded to.</summary>
    public ObjectSet<TBase> ObjectSet { get; private set; }

    /// <summary>
    /// Creates a query over the <typeparamref name="TInherited"/> entities of <paramref name="objectSet"/>.
    /// </summary>
    /// <param name="objectSet">The base-type object set; must not be null.</param>
    /// <exception cref="System.ArgumentNullException"><paramref name="objectSet"/> is null.</exception>
    public SubObjectSet(ObjectSet<TBase> objectSet)
        : base(GetOfTypeCommandText(objectSet), objectSet.Context)
    {
        ObjectSet = objectSet;
    }

    // Evaluated before objectSet.Context in the base(...) call, so a null argument is reported as an
    // ArgumentNullException instead of a NullReferenceException. Returns the command text of
    // objectSet.OfType<TInherited>().
    private static string GetOfTypeCommandText(ObjectSet<TBase> objectSet)
    {
        if (objectSet == null)
        {
            throw new System.ArgumentNullException("objectSet");
        }
        return objectSet.OfType<TInherited>().CommandText;
    }

    #region IObjectSet<TInherited> Members

    /// <summary>Adds <paramref name="entity"/> through the base-type object set.</summary>
    public void AddObject(TInherited entity)
    {
        ObjectSet.AddObject(entity);
    }

    /// <summary>Attaches <paramref name="entity"/> through the base-type object set.</summary>
    public void Attach(TInherited entity)
    {
        ObjectSet.Attach(entity);
    }

    /// <summary>Marks <paramref name="entity"/> for deletion through the base-type object set.</summary>
    public void DeleteObject(TInherited entity)
    {
        ObjectSet.DeleteObject(entity);
    }

    #endregion
}

Posted by Matthieu MEZIL | with no comments

Entity Framework v2: How to get a single entity more easily with EF4

Alex James wrote an extension method that retrieves a single entity from a query, given its entity key.

If we have the key, I think there is no point in offering this on every query: it is only useful on an EntitySet. With EF4, this extension method can be applied to the ObjectSet class instead of the ObjectQuery class.

// In the first version of Entity Framework, the ObjectSet class doesn’t exist; EntitySets are exposed as ObjectQuery instances.

// ObjectSet<T> class inherits from ObjectQuery<T>

This simplifies the code because we can directly use the (Try)GetObjectByKey method:

/// <summary>
/// Extension methods that fetch a single entity by key directly from an <see cref="ObjectSet{TEntity}"/>.
/// </summary>
public static class ObjectSetExtension
{
    /// <summary>
    /// Gets the entity of <paramref name="objectSet"/> whose single-member key equals <paramref name="key"/>,
    /// using <c>ObjectContext.TryGetObjectByKey</c>.
    /// </summary>
    /// <param name="objectSet">The entity set to search.</param>
    /// <param name="key">Value of the entity's only key member.</param>
    /// <returns>The entity, or null when TryGetObjectByKey does not find it.</returns>
    /// <exception cref="System.ArgumentNullException"><paramref name="objectSet"/> is null.</exception>
    /// <exception cref="System.InvalidOperationException">The entity type does not have exactly one key member.</exception>
    public static T Get<T>(this ObjectSet<T> objectSet, object key) where T : class
    {
        if (objectSet == null)
        {
            throw new System.ArgumentNullException("objectSet");
        }

        var keyMembers = objectSet.EntitySet.ElementType.KeyMembers;
        if (keyMembers.Count != 1)
        {
            // Same exception type as the former KeyMembers.Single(), with an actionable message.
            throw new System.InvalidOperationException("Get(object key) requires an entity type with a single key member; use the EntityKeyMember[] overload for composite keys.");
        }

        return TryGet(objectSet, new EntityKey(GetQualifiedEntitySetName(objectSet), keyMembers[0].Name, key));
    }

    /// <summary>
    /// Gets the entity of <paramref name="objectSet"/> identified by the given key members (for composite keys).
    /// </summary>
    /// <param name="objectSet">The entity set to search.</param>
    /// <param name="keys">One name/value pair per key member.</param>
    /// <returns>The entity, or null when TryGetObjectByKey does not find it.</returns>
    /// <exception cref="System.ArgumentNullException"><paramref name="objectSet"/> is null.</exception>
    public static T Get<T>(this ObjectSet<T> objectSet, params EntityKeyMember[] keys) where T : class
    {
        if (objectSet == null)
        {
            throw new System.ArgumentNullException("objectSet");
        }

        return TryGet(objectSet, new EntityKey(GetQualifiedEntitySetName(objectSet), keys));
    }

    // Builds "Container.EntitySet" from the set's own metadata. ObjectContext.DefaultContainerName may be
    // empty, or name a different container than the one this entity set belongs to.
    private static string GetQualifiedEntitySetName<T>(ObjectSet<T> objectSet) where T : class
    {
        var entitySet = objectSet.EntitySet;
        return string.Concat(entitySet.EntityContainer.Name, ".", entitySet.Name);
    }

    // Looks the key up through the context; a failed lookup leaves the result null.
    private static T TryGet<T>(ObjectSet<T> objectSet, EntityKey entityKey) where T : class
    {
        object value;
        objectSet.Context.TryGetObjectByKey(entityKey, out value);
        return (T)value;
    }
}

We can use the second extension method for entities with composite keys.

Entity Framework: Undo Redo v2

After the first version of my Undo/Redo POC, one of my customers wanted to be able to group several actions into a single Undo/Redo step.

So I added two extension methods: BeginGroupOfUndoActions and EndGroupOfUndoActions.

My code is now:

/// <summary>
/// Undo/redo support for scalar property changes on the entities tracked by an <see cref="ObjectContext"/>,
/// with optional grouping of several changes into a single undo/redo step.
/// </summary>
/// <remarks>
/// Only entities implementing INotifyPropertyChanging are monitored, and a change is recorded only when
/// the entity also raises INotifyPropertyChanged. No locking is performed.
/// </remarks>
public static class ObjectContextExtension
{
    // Weak-keyed: unlike a static Dictionary, registering a context here does not keep it (and the
    // entities it tracks) alive after the caller drops it.
    private static readonly System.Runtime.CompilerServices.ConditionalWeakTable<ObjectContext, ObjectContextUndoRedo> _objectContextUndoRedo =
        new System.Runtime.CompilerServices.ConditionalWeakTable<ObjectContext, ObjectContextUndoRedo>();

    /// <summary>
    /// Starts undo/redo tracking on <paramref name="context"/>. Calling it again restarts with empty stacks.
    /// </summary>
    /// <param name="context">Context whose tracked entities are monitored.</param>
    /// <param name="undoStackLength">Maximum number of undo steps kept; older steps are dropped.</param>
    public static void ActivateUndoRedoTracking(this ObjectContext context, int undoStackLength)
    {
        ObjectContextUndoRedo objectContextUndoRedo = _objectContextUndoRedo.GetValue(context, c => new ObjectContextUndoRedo { Context = c });
        objectContextUndoRedo.ActivateUndoRedoTracking(undoStackLength);
    }

    /// <summary>Whether there is a step to undo.</summary>
    /// <exception cref="InvalidOperationException">Tracking was never activated on <paramref name="context"/>.</exception>
    public static bool CanUndo(this ObjectContext context)
    {
        return GetUndoRedo(context).CanUndo;
    }

    /// <summary>Undoes the most recent step (a single change or a group).</summary>
    /// <exception cref="InvalidOperationException">Tracking is not active, or there is nothing to undo.</exception>
    public static void Undo(this ObjectContext context)
    {
        GetUndoRedo(context).Undo();
    }

    /// <summary>Whether there is a step to redo.</summary>
    /// <exception cref="InvalidOperationException">Tracking was never activated on <paramref name="context"/>.</exception>
    public static bool CanRedo(this ObjectContext context)
    {
        return GetUndoRedo(context).CanRedo;
    }

    /// <summary>Redoes the most recently undone step.</summary>
    /// <exception cref="InvalidOperationException">Tracking is not active, or there is nothing to redo.</exception>
    public static void Redo(this ObjectContext context)
    {
        GetUndoRedo(context).Redo();
    }

    /// <summary>
    /// Opens a new undo step: subsequent changes are recorded into it until
    /// <see cref="EndGroupOfUndoActions"/> is called. Calling it again opens another step.
    /// </summary>
    /// <exception cref="InvalidOperationException">Tracking was never activated on <paramref name="context"/>.</exception>
    public static void BeginGroupOfUndoActions(this ObjectContext context)
    {
        GetUndoRedo(context).MultipleActions = true;
    }

    /// <summary>Stops grouping: each subsequent change becomes its own undo step.</summary>
    /// <exception cref="InvalidOperationException">Tracking was never activated on <paramref name="context"/>.</exception>
    public static void EndGroupOfUndoActions(this ObjectContext context)
    {
        GetUndoRedo(context).MultipleActions = false;
    }

    // Returns the undo/redo state of a context on which ActivateUndoRedoTracking was called (single lookup).
    private static ObjectContextUndoRedo GetUndoRedo(ObjectContext context)
    {
        ObjectContextUndoRedo objectContextUndoRedo;
        if (!_objectContextUndoRedo.TryGetValue(context, out objectContextUndoRedo))
        {
            throw new InvalidOperationException("Undo/redo tracking has not been activated on this context; call ActivateUndoRedoTracking first.");
        }
        return objectContextUndoRedo;
    }

    private class ObjectContextUndoRedo
    {
        // Index 0 is the most recent step; each step is a group of one or more property changes.
        private List<List<UndoRedoAction>> _undo, _redo;
        private int _undoStackLength;
        // False while Undo/Redo replay changes, so the replay itself is not recorded.
        private bool _trackChanges;
        // Set once the change handlers are subscribed. Re-activation must not subscribe them again,
        // otherwise every change would be recorded twice.
        private bool _handlersAttached;

        public ObjectContext Context { get; set; }

        private bool _multipleActions;
        public bool MultipleActions
        {
            get { return _multipleActions; }
            set
            {
                _multipleActions = value;
                if (value)
                {
                    // Opening a group starts a new, initially empty, undo step.
                    _undo.Insert(0, new List<UndoRedoAction>());
                    TrimUndoStack();
                }
            }
        }

        public void ActivateUndoRedoTracking(int undoStackLength)
        {
            _undoStackLength = undoStackLength;
            _undo = new List<List<UndoRedoAction>>(undoStackLength);
            _redo = new List<List<UndoRedoAction>>(undoStackLength);
            _trackChanges = true;

            // The handlers read the fields above, so a re-activation only needs this reset.
            if (_handlersAttached)
            {
                return;
            }
            _handlersAttached = true;

            var objectStateEntries = Context.ObjectStateManager.GetObjectStateEntries(EntityState.Added | EntityState.Deleted | EntityState.Modified | EntityState.Unchanged).ToList();

            PropertyChangingEventHandler entityModifing = (sender, e) => OnPropertyChanging(sender, e.PropertyName);

            foreach (var e in objectStateEntries.Select(ose => ose.Entity as INotifyPropertyChanging).Where(inpc => inpc != null))
            {
                e.PropertyChanging += entityModifing;
            }

            // Entities added to the context later are monitored as well.
            Context.ObjectStateManager.ObjectStateManagerChanged += (sender, e) =>
            {
                if (e.Action == CollectionChangeAction.Add)
                {
                    var inpc = e.Element as INotifyPropertyChanging;
                    if (inpc != null)
                    {
                        inpc.PropertyChanging += entityModifing;
                    }
                }
            };
        }

        // Captures the old value before a property changes. A one-shot PropertyChanged handler then
        // captures the new value and records the resulting undo/redo action.
        private void OnPropertyChanging(object sender, string propertyName)
        {
            // Names that do not resolve to a readable and writable property (null, empty, unknown or
            // read-only) cannot be undone; ignore them instead of throwing.
            if (string.IsNullOrEmpty(propertyName))
            {
                return;
            }
            var propInfo = sender.GetType().GetProperty(propertyName, BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
            if (propInfo == null || !propInfo.CanRead || !propInfo.CanWrite)
            {
                return;
            }

            var value = propInfo.GetValue(sender, null);
            var undoRedoAction = new UndoRedoAction { EntityState = EntityState.Modified, UndoAction = () => propInfo.SetValue(sender, value, null) };

            var inpc = sender as INotifyPropertyChanged;
            if (inpc == null)
            {
                return;
            }

            PropertyChangedEventHandler entityModified = null;
            entityModified = (s, e2) =>
            {
                if (propertyName == e2.PropertyName && _trackChanges)
                {
                    var newValue = propInfo.GetValue(sender, null);
                    undoRedoAction.RedoAction = () => propInfo.SetValue(sender, newValue, null);
                    Record(undoRedoAction);
                }
                // One-shot: remove the handler after the first PropertyChanged notification.
                inpc.PropertyChanged -= entityModified;
            };
            inpc.PropertyChanged += entityModified;
        }

        // Pushes a recorded change (into the open group when grouping), trims the stack and invalidates redo.
        private void Record(UndoRedoAction undoRedoAction)
        {
            // In group mode, append to the current step. If the stack is empty (for example after
            // re-activation), start a new step instead of indexing an empty list.
            if (MultipleActions && _undo.Count > 0)
            {
                _undo[0].Add(undoRedoAction);
            }
            else
            {
                _undo.Insert(0, new List<UndoRedoAction> { undoRedoAction });
            }
            TrimUndoStack();
            _redo.Clear();
        }

        // Drops the oldest steps beyond the configured maximum length.
        private void TrimUndoStack()
        {
            while (_undo.Count > _undoStackLength)
            {
                _undo.RemoveAt(_undo.Count - 1);
            }
        }

        public bool CanUndo
        {
            get { return _undo != null && _undo.Count > 0; }
        }

        public void Undo()
        {
            if (!CanUndo)
            {
                throw new InvalidOperationException();
            }
            var undoRedoAction = _undo[0];
            _undo.RemoveAt(0);
            _trackChanges = false;
            try
            {
                // Undo a group in reverse order, so a property changed several times within the group
                // ends up with the value it had before the group started.
                for (int i = undoRedoAction.Count - 1; i >= 0; i--)
                {
                    undoRedoAction[i].UndoAction();
                }
            }
            finally
            {
                // Resume recording even if a property setter throws.
                _trackChanges = true;
            }
            _redo.Insert(0, undoRedoAction);
        }

        public bool CanRedo
        {
            get { return _redo != null && _redo.Count > 0; }
        }

        public void Redo()
        {
            if (!CanRedo)
            {
                throw new InvalidOperationException();
            }
            var undoRedoAction = _redo[0];
            _redo.RemoveAt(0);
            _trackChanges = false;
            try
            {
                // Redo replays the group in its original order.
                foreach (var redoAction in undoRedoAction)
                {
                    redoAction.RedoAction();
                }
            }
            finally
            {
                _trackChanges = true;
            }
            _undo.Insert(0, undoRedoAction);
        }
    }

    // One recorded property change: UndoAction restores the old value, RedoAction re-applies the new one.
    private class UndoRedoAction
    {
        public EntityState EntityState { get; set; }
        public Action UndoAction { get; set; }
        public Action RedoAction { get; set; }
    }
}

And to demonstrate it, I wrote a unit test:

[TestClass]
public class ObjectContextExtensionTest
{
    /// <summary>
    /// Exercises single-step and grouped undo/redo on the CategoryName of two Northwind categories.
    /// </summary>
    [TestMethod]
    public void Test()
    {
        using (var context = new NorthwindEntities())
        {
            var first = context.Categories.First();
            context.ActivateUndoRedoTracking(5);

            // Three single-step changes: first -> "CN", second -> "C2N", first -> "CN2".
            var firstOriginal = first.CategoryName;
            first.CategoryName = "CN";
            var second = context.Categories.OrderBy(category => category.CategoryID).Skip(1).First();
            var secondOriginal = second.CategoryName;
            second.CategoryName = "C2N";
            first.CategoryName = "CN2";

            // Checks both category names, first then second.
            System.Action<string, string> expectNames = (expectedFirst, expectedSecond) =>
            {
                Assert.AreEqual(expectedFirst, first.CategoryName);
                Assert.AreEqual(expectedSecond, second.CategoryName);
            };

            // Undo the three steps one by one...
            context.Undo();
            expectNames("CN", "C2N");
            context.Undo();
            expectNames("CN", secondOriginal);
            context.Undo();
            expectNames(firstOriginal, secondOriginal);

            // ...then redo them.
            context.Redo();
            expectNames("CN", secondOriginal);
            context.Redo();
            expectNames("CN", "C2N");
            context.Redo();
            expectNames("CN2", "C2N");

            // Two groups (the second opened while the first is still open), then two single steps.
            context.BeginGroupOfUndoActions();
            first.CategoryName = "CN3";
            second.CategoryName = "C2N2";
            context.BeginGroupOfUndoActions();
            first.CategoryName = "CN4";
            second.CategoryName = "C2N3";
            context.EndGroupOfUndoActions();
            first.CategoryName = "CN5";
            second.CategoryName = "C2N4";

            context.Undo();
            expectNames("CN5", "C2N3");
            context.Undo();
            expectNames("CN4", "C2N3");
            context.Undo();
            expectNames("CN3", "C2N2");
            context.Undo();
            expectNames("CN2", "C2N");

            context.Redo();
            expectNames("CN3", "C2N2");
            context.Redo();
            expectNames("CN4", "C2N3");
            context.Redo();
            expectNames("CN5", "C2N3");
            context.Redo();
            expectNames("CN5", "C2N4");
        }
    }
}

 

 

EF : Undo Redo

How can we use Undo/Redo with Entity Framework? EF doesn’t support it.

So we will have to do it ourselves.

We could clone the entities and push them onto a stack, but I worry that memory usage would grow too much that way.

So I prefer another solution, based on Action delegates.

Note that for this POC, I just manage Undo/Redo on scalar properties.

/// <summary>
/// Undo/redo support (proof of concept) for scalar property changes on the entities tracked by an
/// <see cref="ObjectContext"/>. Each property change is one undo/redo step.
/// </summary>
/// <remarks>
/// Only entities implementing INotifyPropertyChanging are monitored, and a change is recorded only when
/// the entity also raises INotifyPropertyChanged. No locking is performed.
/// </remarks>
public static class ObjectContextExtension
{
    // Weak-keyed: unlike a static Dictionary, registering a context here does not keep it (and the
    // entities it tracks) alive after the caller drops it.
    private static readonly System.Runtime.CompilerServices.ConditionalWeakTable<ObjectContext, ObjectContextUndoRedo> _objectContextUndoRedo =
        new System.Runtime.CompilerServices.ConditionalWeakTable<ObjectContext, ObjectContextUndoRedo>();

    /// <summary>
    /// Starts undo/redo tracking on <paramref name="context"/>. Calling it again restarts with empty stacks.
    /// </summary>
    /// <param name="context">Context whose tracked entities are monitored.</param>
    /// <param name="undoStackLength">Maximum number of undo steps kept; older steps are dropped.</param>
    public static void ActivateUndoRedoTracking(this ObjectContext context, int undoStackLength)
    {
        ObjectContextUndoRedo objectContextUndoRedo = _objectContextUndoRedo.GetValue(context, c => new ObjectContextUndoRedo { Context = c });
        objectContextUndoRedo.ActivateUndoRedoTracking(undoStackLength);
    }

    /// <summary>Whether there is a change to undo.</summary>
    /// <exception cref="InvalidOperationException">Tracking was never activated on <paramref name="context"/>.</exception>
    public static bool CanUndo(this ObjectContext context)
    {
        return GetUndoRedo(context).CanUndo;
    }

    /// <summary>Undoes the most recent change.</summary>
    /// <exception cref="InvalidOperationException">Tracking is not active, or there is nothing to undo.</exception>
    public static void Undo(this ObjectContext context)
    {
        GetUndoRedo(context).Undo();
    }

    /// <summary>Whether there is a change to redo.</summary>
    /// <exception cref="InvalidOperationException">Tracking was never activated on <paramref name="context"/>.</exception>
    public static bool CanRedo(this ObjectContext context)
    {
        return GetUndoRedo(context).CanRedo;
    }

    /// <summary>Redoes the most recently undone change.</summary>
    /// <exception cref="InvalidOperationException">Tracking is not active, or there is nothing to redo.</exception>
    public static void Redo(this ObjectContext context)
    {
        GetUndoRedo(context).Redo();
    }

    // Single lookup (TryGetValue) instead of ContainsKey followed by the indexer.
    private static ObjectContextUndoRedo GetUndoRedo(ObjectContext context)
    {
        ObjectContextUndoRedo objectContextUndoRedo;
        if (!_objectContextUndoRedo.TryGetValue(context, out objectContextUndoRedo))
        {
            throw new InvalidOperationException("Undo/redo tracking has not been activated on this context; call ActivateUndoRedoTracking first.");
        }
        return objectContextUndoRedo;
    }

    private class ObjectContextUndoRedo
    {
        // Index 0 is the most recent change.
        private List<UndoRedoAction> _undo, _redo;
        private int _undoStackLength;
        // False while Undo/Redo replay a change, so the replay itself is not recorded.
        private bool _trackChanges;
        // Set once the change handlers are subscribed. Re-activation must not subscribe them again,
        // otherwise every change would be recorded twice.
        private bool _handlersAttached;

        public ObjectContext Context { get; set; }

        public void ActivateUndoRedoTracking(int undoStackLength)
        {
            _undoStackLength = undoStackLength;
            _undo = new List<UndoRedoAction>(undoStackLength);
            _redo = new List<UndoRedoAction>(undoStackLength);
            _trackChanges = true;

            // The handlers read the fields above, so a re-activation only needs this reset.
            if (_handlersAttached)
            {
                return;
            }
            _handlersAttached = true;

            var objectStateEntries = Context.ObjectStateManager.GetObjectStateEntries(EntityState.Added | EntityState.Deleted | EntityState.Modified | EntityState.Unchanged).ToList();

            PropertyChangingEventHandler entityModifing = (sender, e) => OnPropertyChanging(sender, e.PropertyName);

            foreach (var e in objectStateEntries.Select(ose => ose.Entity as INotifyPropertyChanging).Where(inpc => inpc != null))
            {
                e.PropertyChanging += entityModifing;
            }

            // Entities added to the context later are monitored as well.
            Context.ObjectStateManager.ObjectStateManagerChanged += (sender, e) =>
            {
                if (e.Action == CollectionChangeAction.Add)
                {
                    var inpc = e.Element as INotifyPropertyChanging;
                    if (inpc != null)
                    {
                        inpc.PropertyChanging += entityModifing;
                    }
                }
            };
        }

        // Captures the old value before a property changes. A one-shot PropertyChanged handler then
        // captures the new value and pushes the resulting undo/redo action.
        private void OnPropertyChanging(object sender, string propertyName)
        {
            // Names that do not resolve to a readable and writable public property (null, empty, unknown
            // or read-only) cannot be undone; ignore them instead of throwing.
            if (string.IsNullOrEmpty(propertyName))
            {
                return;
            }
            var propInfo = sender.GetType().GetProperty(propertyName);
            if (propInfo == null || !propInfo.CanRead || !propInfo.CanWrite)
            {
                return;
            }

            var value = propInfo.GetValue(sender, null);
            var undoRedoAction = new UndoRedoAction { EntityState = EntityState.Modified, UndoAction = () => propInfo.SetValue(sender, value, null) };

            var inpc = sender as INotifyPropertyChanged;
            if (inpc == null)
            {
                return;
            }

            PropertyChangedEventHandler entityModified = null;
            entityModified = (s, e2) =>
            {
                if (propertyName == e2.PropertyName && _trackChanges)
                {
                    var newValue = propInfo.GetValue(sender, null);
                    undoRedoAction.RedoAction = () => propInfo.SetValue(sender, newValue, null);
                    _undo.Insert(0, undoRedoAction);
                    if (_undo.Count > _undoStackLength)
                    {
                        _undo.RemoveAt(_undoStackLength);
                    }
                    _redo.Clear();
                }
                // One-shot: remove the handler after the first PropertyChanged notification.
                inpc.PropertyChanged -= entityModified;
            };
            inpc.PropertyChanged += entityModified;
        }

        // Runs an undo/redo delegate with recording suspended; recording resumes even if it throws.
        private void Replay(Action action)
        {
            _trackChanges = false;
            try
            {
                action();
            }
            finally
            {
                _trackChanges = true;
            }
        }

        public bool CanUndo
        {
            get { return _undo != null && _undo.Count > 0; }
        }

        public void Undo()
        {
            if (!CanUndo)
            {
                throw new InvalidOperationException();
            }
            var undoRedoAction = _undo[0];
            _undo.RemoveAt(0);
            Replay(undoRedoAction.UndoAction);
            _redo.Insert(0, undoRedoAction);
        }

        public bool CanRedo
        {
            get { return _redo != null && _redo.Count > 0; }
        }

        public void Redo()
        {
            if (!CanRedo)
            {
                throw new InvalidOperationException();
            }
            var undoRedoAction = _redo[0];
            _redo.RemoveAt(0);
            Replay(undoRedoAction.RedoAction);
            _undo.Insert(0, undoRedoAction);
        }
    }

    // One recorded property change: UndoAction restores the old value, RedoAction re-applies the new one.
    private class UndoRedoAction
    {
        public EntityState EntityState { get; set; }
        public Action UndoAction { get; set; }
        public Action RedoAction { get; set; }
    }
}

In my opinion, the main question is: does a complete undo/redo make sense? Indeed, we would run into some big issues with cascades, identity columns, concurrent access…

How to split your EDM v2?

After my previous post about it, my customer asked me the following question: how can we get a complete graph (with categories, suppliers, products, order details, orders, customers and employees)?

To do this, we have to add “sort of” navigation properties to the entities, like this:

private Supplier _supplier;

// True once _supplier holds a loaded or explicitly assigned value. Without this flag, a null
// supplier (e.g. a product with no supplier) would trigger GetSupplier() again on every read,
// and a null assigned by pre-loading code would be silently replaced by a new query.
private bool _supplierLoaded;

/// <summary>
/// Supplier of this entity, loaded through <c>GetSupplier()</c> on first read unless a value
/// (including null) has already been assigned.
/// </summary>
/// <remarks>
/// Changes aren't saved
/// </remarks>
public Supplier Supplier
{
    get
    {
        if (!_supplierLoaded)
        {
            _supplier = this.GetSupplier();
            // Set only after a successful load so a failed query is retried on the next read.
            _supplierLoaded = true;
        }
        return _supplier;
    }
    set
    {
        _supplier = value;
        _supplierLoaded = true;
    }
}

 

private IEnumerable<Orders.OrderDetail> _orderDetails;

/// <summary>
/// Order details of this entity. When no value is cached, they are fetched through
/// <c>GetOrderDetails()</c> and cached for later reads.
/// </summary>
/// <remarks>
/// Changes aren't saved
/// </remarks>
public IEnumerable<Orders.OrderDetail> OrderDetails
{
    get { return _orderDetails ?? (_orderDetails = this.GetOrderDetails()); }
    set { _orderDetails = value; }
}

Then, the idea is to populate these properties up front.

For this, we can do the following:

// Per-entity approach: load all categories with their products. Then, for every product,
// open separate contexts to query its supplier and its order details (with order and customer).
// For each of those orders, open another context to query the order's employee.
// This issues several queries per product and per order.
using (var stockContext = new StocksEntities())

{

    var categories = stockContext.Categories.Include("Products").ToList();

    // Enumerate the Product entities that stockContext is now tracking as Unchanged.
    foreach (var p in stockContext.ObjectStateManager.GetObjectStateEntries(EntityState.Unchanged).Select(e => e.Entity).OfType<Entities.Stocks.Product>())

    {

        using (var supplierContext = new SuppliersEntities())

        {

            // Resolve the supplier through the Suppliers model's own Product entity set.
            p.Supplier = (from p2 in supplierContext.Products

                          where p2.ProductID == p.ProductID

                          select p2.Supplier).FirstOrDefault();

        }

        // A new context per product: its state manager holds only the entities loaded by the
        // query below, so the inner loop visits only this product's orders.
        using (var orderContext = new OrdersEntities())

        {

            p.OrderDetails = (from od in orderContext.OrderDetails.Include("Order.Customer")

                              where od.ProductID == p.ProductID

                              select od).ToList();

            foreach (var o in orderContext.ObjectStateManager.GetObjectStateEntries(EntityState.Unchanged).Select(e => e.Entity).OfType<Entities.Orders.Order>())

            {

                // One more context and query per order.
                using (var employeeContext = new EmployeesEntities())

                {

                    o.Employee = (from oe in employeeContext.Orders

                                  where oe.OrderID == o.OrderID

                                  select oe.Employee).FirstOrDefault();

                }

            }

        }

    }

}

However, this approach can generate a lot of SQL queries, so execution isn’t very fast. If we want to load only some of the categories (with their complete graphs), we can instead issue a single query per context and connect the entities in memory:

// Load categories together with their complete cross-model graph using one query per
// context, instead of one query per product/order.
using (var stockContext = new StocksEntities())
using (var supplierContext = new SuppliersEntities())
using (var orderContext = new OrdersEntities())
using (var employeeContext = new EmployeesEntities())
{
    var categories = stockContext.Categories.Include("Products").ToList();

    // Materialize once: this sequence is enumerated several times below.
    var products = categories.SelectMany(c => c.Products).ToList();

    var supplierIds = (from p in products
                       where p.SupplierID.HasValue
                       select p.SupplierID.Value).Distinct().ToList();
    var suppliers = supplierContext.Suppliers.Where(BuildContainsExpression<Supplier, int>(s => s.SupplierID, supplierIds)).ToList();

    var productIds = products.Select(p => p.ProductID).ToList();
    var orderDetails = orderContext.OrderDetails.Include("Order.Customer").Where(BuildContainsExpression<Entities.Orders.OrderDetail, int>(od => od.ProductID, productIds)).ToList();

    // Several order details can reference the same Order instance, so de-duplicate.
    var orders = orderDetails.Select(od => od.Order).Distinct().ToList();

    var employeeIds = (from o in orders
                       where o.EmployeeID.HasValue
                       select o.EmployeeID.Value).Distinct().ToList();
    var employees = employeeContext.Employees.Where(BuildContainsExpression<Employee, int>(e => e.EmployeeID, employeeIds)).ToList();

    var orderDetailsByProduct = orderDetails.ToLookup(od => od.ProductID);
    foreach (var p in products)
    {
        p.Supplier = suppliers.FirstOrDefault(s => s.SupplierID == p.SupplierID);
        // Materialize here. A deferred query would capture the loop variable p, which is
        // shared by all iterations before C# 5, so every product would end up filtering by
        // the last product's ID. It would also be re-evaluated on every read.
        p.OrderDetails = orderDetailsByProduct[p.ProductID].ToList();
    }

    // Each order needs its employee assigned once, not once per product.
    foreach (var o in orders)
        o.Employee = employees.FirstOrDefault(e => o.EmployeeID == e.EmployeeID);
}

If we want to load all entities, we can skip the filtering and directly use the following code:

// Load every entity of each model (one unfiltered query per context), then wire up the
// cross-model "navigation" properties in memory.
using (var stockContext = new StocksEntities())
using (var supplierContext = new SuppliersEntities())
using (var orderContext = new OrdersEntities())
using (var employeeContext = new EmployeesEntities())
{
    var categories = stockContext.Categories.Include("Products").ToList();

    // Materialize once: this sequence is enumerated in the loop below.
    var products = categories.SelectMany(c => c.Products).ToList();
    var suppliers = supplierContext.Suppliers.ToList();
    var orderDetails = orderContext.OrderDetails.Include("Order.Customer").ToList();

    // Several order details can reference the same Order instance, so de-duplicate.
    var orders = orderDetails.Select(od => od.Order).Distinct().ToList();
    var employees = employeeContext.Employees.ToList();

    var orderDetailsByProduct = orderDetails.ToLookup(od => od.ProductID);
    foreach (var p in products)
    {
        p.Supplier = suppliers.FirstOrDefault(s => s.SupplierID == p.SupplierID);
        // Materialize here. A deferred query would capture the loop variable p, which is
        // shared by all iterations before C# 5, so every product would end up filtering by
        // the last product's ID. It would also be re-evaluated on every read.
        p.OrderDetails = orderDetailsByProduct[p.ProductID].ToList();
    }

    // Each order needs its employee assigned once, not once per product.
    foreach (var o in orders)
        o.Employee = employees.FirstOrDefault(e => o.EmployeeID == e.EmployeeID);
}