/// <summary>
/// Extension methods that add undo/redo of property changes to an <see cref="ObjectContext"/>.
/// </summary>
/// <remarks>
/// Tracking is started with <see cref="ActivateUndoRedoTracking"/>; every other method throws
/// <see cref="InvalidOperationException"/> for a context on which it has not been called.
/// Only entities that implement both <see cref="INotifyPropertyChanging"/> and
/// <see cref="INotifyPropertyChanged"/> produce undo entries: the old value is captured on
/// <c>PropertyChanging</c> and the new value on the following <c>PropertyChanged</c>.
/// </remarks>
public static class ObjectContextExtension
{
    // One tracker per activated context. Nothing in this class removes entries, so an activated
    // context stays reachable through this map. The map is not synchronized.
    private static readonly Dictionary<ObjectContext, ObjectContextUndoRedo> _objectContextUndoRedo = new Dictionary<ObjectContext, ObjectContextUndoRedo>();

    /// <summary>
    /// Starts (or restarts) undo/redo tracking for <paramref name="context"/>.
    /// Calling it again on the same context empties both stacks and applies the new stack length.
    /// </summary>
    /// <param name="context">The context whose entities should be tracked.</param>
    /// <param name="undoStackLength">Maximum number of undo entries (single changes or groups) to keep.</param>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="undoStackLength"/> is negative.</exception>
    public static void ActivateUndoRedoTracking(this ObjectContext context, int undoStackLength)
    {
        if (context == null)
            throw new ArgumentNullException("context");
        // Validate up front so a bad length can never leave a half-initialised tracker registered.
        if (undoStackLength < 0)
            throw new ArgumentOutOfRangeException("undoStackLength");

        ObjectContextUndoRedo tracker;
        var isNew = !_objectContextUndoRedo.TryGetValue(context, out tracker) || tracker == null;
        if (isNew)
            tracker = new ObjectContextUndoRedo { Context = context };

        tracker.ActivateUndoRedoTracking(undoStackLength);

        // Register only after activation succeeded.
        if (isNew)
            _objectContextUndoRedo[context] = tracker;
    }

    /// <summary>Returns true when there is at least one entry on the undo stack.</summary>
    /// <exception cref="InvalidOperationException">Tracking was not activated for <paramref name="context"/>.</exception>
    public static bool CanUndo(this ObjectContext context)
    {
        return GetTracker(context).CanUndo;
    }

    /// <summary>Reverts the most recent undo entry and moves it to the redo stack.</summary>
    /// <exception cref="InvalidOperationException">Tracking was not activated, or there is nothing to undo.</exception>
    public static void Undo(this ObjectContext context)
    {
        GetTracker(context).Undo();
    }

    /// <summary>Returns true when there is at least one entry on the redo stack.</summary>
    /// <exception cref="InvalidOperationException">Tracking was not activated for <paramref name="context"/>.</exception>
    public static bool CanRedo(this ObjectContext context)
    {
        return GetTracker(context).CanRedo;
    }

    /// <summary>Re-applies the most recently undone entry and moves it back to the undo stack.</summary>
    /// <exception cref="InvalidOperationException">Tracking was not activated, or there is nothing to redo.</exception>
    public static void Redo(this ObjectContext context)
    {
        GetTracker(context).Redo();
    }

    /// <summary>
    /// Makes all changes recorded from now on, until <see cref="EndGroupOfUndoActions"/>, belong to a
    /// single undo entry. Calling it again while a group is open starts a new group.
    /// </summary>
    /// <exception cref="InvalidOperationException">Tracking was not activated for <paramref name="context"/>.</exception>
    public static void BeginGroupOfUndoActions(this ObjectContext context)
    {
        GetTracker(context).MultipleActions = true;
    }

    /// <summary>Closes the current group; later changes become individual undo entries again.</summary>
    /// <exception cref="InvalidOperationException">Tracking was not activated for <paramref name="context"/>.</exception>
    public static void EndGroupOfUndoActions(this ObjectContext context)
    {
        GetTracker(context).MultipleActions = false;
    }

    // Looks up the tracker for a context, or throws when tracking was never activated.
    private static ObjectContextUndoRedo GetTracker(ObjectContext context)
    {
        if (context == null)
            throw new ArgumentNullException("context");

        ObjectContextUndoRedo tracker;
        if (!_objectContextUndoRedo.TryGetValue(context, out tracker) || tracker == null)
            throw new InvalidOperationException("Undo/redo tracking has not been activated for this ObjectContext. Call ActivateUndoRedoTracking first.");
        return tracker;
    }

    /// <summary>
    /// Per-context undo/redo state. Index 0 of each stack is the most recent entry;
    /// every entry is a list of actions in the order they were recorded.
    /// </summary>
    private class ObjectContextUndoRedo
    {
        private List<List<UndoRedoAction>> _undo, _redo;
        private int _undoStackLength;
        private bool _trackChanges;
        private bool _multipleActions;
        // True once the entity/state-manager event handlers have been attached.
        private bool _subscribed;
        // The group that grouped changes are currently appended to; null when the next
        // recorded change has to start a new undo entry.
        private List<UndoRedoAction> _openGroup;

        public ObjectContext Context { get; set; }

        /// <summary>
        /// When true, recorded changes are collected into one undo entry. Every assignment
        /// (true or false) closes the currently open group.
        /// </summary>
        public bool MultipleActions
        {
            get { return _multipleActions; }
            set
            {
                _multipleActions = value;
                // The group itself is created lazily by Record, so a Begin/End pair without any
                // change leaves no empty entry on the undo stack.
                _openGroup = null;
            }
        }

        /// <summary>
        /// Resets both stacks and makes sure change notifications of all entities known to the
        /// context's state manager (now and in the future) are observed.
        /// </summary>
        public void ActivateUndoRedoTracking(int undoStackLength)
        {
            _undoStackLength = undoStackLength;
            _undo = new List<List<UndoRedoAction>>(undoStackLength);
            _redo = new List<List<UndoRedoAction>>(undoStackLength);
            _openGroup = null;
            _trackChanges = true;

            // Attach handlers only once; re-attaching on every activation would record each
            // change as many times as tracking had been activated.
            if (_subscribed)
                return;

            var stateManager = Context.ObjectStateManager;
            var entries = stateManager.GetObjectStateEntries(EntityState.Added | EntityState.Deleted | EntityState.Modified | EntityState.Unchanged).ToList();
            foreach (var entry in entries)
                Subscribe(entry.Entity);

            stateManager.ObjectStateManagerChanged += (sender, e) =>
            {
                if (e.Action == CollectionChangeAction.Add)
                    Subscribe(e.Element);
            };
            _subscribed = true;
        }

        public bool CanUndo
        {
            get { return _undo != null && _undo.Count > 0; }
        }

        /// <summary>Reverts the newest undo entry and pushes it onto the redo stack.</summary>
        public void Undo()
        {
            if (!CanUndo)
                throw new InvalidOperationException("There is nothing to undo.");

            var entry = _undo[0];
            _undo.RemoveAt(0);
            // Changes made after this point must not be appended to an entry that left the undo stack.
            _openGroup = null;

            _trackChanges = false;
            try
            {
                // Newest first: when one group changed the same property more than once, the
                // oldest captured value has to be written last to restore the initial state.
                for (var i = entry.Count - 1; i >= 0; i--)
                    entry[i].UndoAction();
            }
            finally
            {
                // Never leave tracking switched off because a setter threw.
                _trackChanges = true;
            }
            _redo.Insert(0, entry);
        }

        public bool CanRedo
        {
            get { return _redo != null && _redo.Count > 0; }
        }

        /// <summary>Re-applies the newest redo entry and pushes it back onto the undo stack.</summary>
        public void Redo()
        {
            if (!CanRedo)
                throw new InvalidOperationException("There is nothing to redo.");

            var entry = _redo[0];
            _redo.RemoveAt(0);
            _openGroup = null;

            _trackChanges = false;
            try
            {
                // Oldest first, i.e. the order in which the changes were originally made.
                foreach (var action in entry)
                    action.RedoAction();
            }
            finally
            {
                _trackChanges = true;
            }
            _undo.Insert(0, entry);
        }

        // Attaches the PropertyChanging handler to an entity exactly once; objects that do not
        // implement INotifyPropertyChanging (including null) are ignored.
        private void Subscribe(object entity)
        {
            var notifier = entity as INotifyPropertyChanging;
            if (notifier == null)
                return;
            // Remove-then-add keeps the subscription single even if the same entity is reported
            // more than once.
            notifier.PropertyChanging -= OnEntityPropertyChanging;
            notifier.PropertyChanging += OnEntityPropertyChanging;
        }

        // Captures the value a property has before it changes and waits for the matching
        // PropertyChanged notification to capture the new value and record the undo entry.
        private void OnEntityPropertyChanging(object sender, PropertyChangingEventArgs e)
        {
            // A null/empty name cannot be mapped to a single property.
            if (sender == null || string.IsNullOrEmpty(e.PropertyName))
                return;

            var propInfo = sender.GetType().GetProperty(e.PropertyName, BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
            // Undo/redo needs to both read and write the property through reflection.
            if (propInfo == null || !propInfo.CanRead || !propInfo.CanWrite)
                return;

            // Without PropertyChanged there is no point at which the new value could be captured.
            var notifier = sender as INotifyPropertyChanged;
            if (notifier == null)
                return;

            var oldValue = propInfo.GetValue(sender, null);
            var action = new UndoRedoAction
            {
                EntityState = EntityState.Modified,
                UndoAction = () => propInfo.SetValue(sender, oldValue, null)
            };

            PropertyChangedEventHandler entityModified = null;
            entityModified = (s, e2) =>
            {
                // _trackChanges is false while Undo/Redo replay values, so replays are not recorded.
                if (e.PropertyName == e2.PropertyName && _trackChanges)
                {
                    var newValue = propInfo.GetValue(sender, null);
                    action.RedoAction = () => propInfo.SetValue(sender, newValue, null);
                    Record(action);
                }
                // One-shot listener: it detaches on the first PropertyChanged it sees, whichever
                // property that notification is for.
                notifier.PropertyChanged -= entityModified;
            };
            notifier.PropertyChanged += entityModified;
        }

        // Adds a completed action to the undo stack (into the open group when grouping) and
        // invalidates the redo stack.
        private void Record(UndoRedoAction action)
        {
            if (_multipleActions && _openGroup != null)
            {
                _openGroup.Add(action);
            }
            else
            {
                var entry = new List<UndoRedoAction> { action };
                _undo.Insert(0, entry);
                _openGroup = _multipleActions ? entry : null;

                // Drop the oldest entries once the configured length is exceeded.
                while (_undo.Count > _undoStackLength)
                    _undo.RemoveAt(_undo.Count - 1);
            }
            _redo.Clear();
        }
    }

    /// <summary>A single recorded change together with the delegates that revert and re-apply it.</summary>
    private class UndoRedoAction
    {
        public EntityState EntityState { get; set; }
        public Action UndoAction { get; set; }
        public Action RedoAction { get; set; }
    }
}