/// <summary>
/// Code generator (derived from <c>BaseCodeGeneratorWithSite</c>) that turns the conceptual model (CSDL)
/// embedded in an EDMX document into C# or VB entity classes using <see cref="EntityClassGenerator"/>.
/// On top of the standard output it:
/// <list type="bullet">
/// <item><description>copies attribute annotations, i.e. metadata properties whose name starts with
/// <c>http://tempuri.org/AttributeAnnotations</c>, onto the generated types and properties; and</description></item>
/// <item><description>adds to the generated entity container one <c>IQueryable&lt;Derived&gt;</c> property per derived
/// entity type, implemented as <c>OfType&lt;Derived&gt;()</c> over the entity set of the hierarchy's root type.</description></item>
/// </list>
/// </summary>
public class SampleEdmxCodeGenerator : BaseCodeGeneratorWithSite
{
    // XML namespaces of the EDMX envelope and of the CSDL schema it wraps.
    private const string EdmxNamespace = "http://schemas.microsoft.com/ado/2007/06/edmx";
    private const string EdmNamespace = "http://schemas.microsoft.com/ado/2006/04/edm";

    // Metadata properties whose name starts with this prefix carry attribute annotations.
    private const string AttributeAnnotationsNamespace = "http://tempuri.org/AttributeAnnotations";

    // Entity container most recently passed to OnTypeGenerated. It is reset at the start of every GenerateCode run.
    private EntityContainer _objectContext;
    private Dictionary<string, string> _entitySetNames;
    private Dictionary<string, List<string>> _typesHierarchyToAddInObjectContext;

    /// <summary>
    /// Maps the simple name of an entity set's element type to the name of that entity set, for the sets
    /// encountered during the current run. Created lazily; reset at the start of every <c>GenerateCode</c> call.
    /// </summary>
    public Dictionary<string, string> EntitySetNames
    {
        get
        {
            if (_entitySetNames == null)
            {
                _entitySetNames = new Dictionary<string, string>();
            }
            return _entitySetNames;
        }
    }

    /// <summary>
    /// Maps the simple name of a root entity type to the simple names of the entity types derived from it,
    /// in document order. Created lazily; reset at the start of every <c>GenerateCode</c> call.
    /// </summary>
    private Dictionary<string, List<string>> TypesHierarchyToAddInObjectContext
    {
        get
        {
            if (_typesHierarchyToAddInObjectContext == null)
            {
                _typesHierarchyToAddInObjectContext = new Dictionary<string, List<string>>();
            }
            return _typesHierarchyToAddInObjectContext;
        }
    }

    /// <summary>
    /// Returns the simple names of the entity types registered as derived from <paramref name="baseType"/>,
    /// or an empty sequence when none were registered.
    /// </summary>
    private IEnumerable<string> GetSubEntitiesOf(EntityType baseType)
    {
        List<string> subEntities;
        if (_typesHierarchyToAddInObjectContext != null
            && _typesHierarchyToAddInObjectContext.TryGetValue(baseType.Name, out subEntities))
        {
            return subEntities;
        }
        return Enumerable.Empty<string>();
    }

    /// <summary>
    /// Records <paramref name="typeToAdd"/> as derived from <paramref name="baseType"/> (both simple names).
    /// </summary>
    private void AddSubEntities(string baseType, string typeToAdd)
    {
        List<string> subEntities;
        if (!TypesHierarchyToAddInObjectContext.TryGetValue(baseType, out subEntities))
        {
            subEntities = new List<string>();
            TypesHierarchyToAddInObjectContext.Add(baseType, subEntities);
        }
        subEntities.Add(typeToAdd);
    }

    /// <summary>
    /// Prefixes the base extension with ".Designer" (presumably producing e.g. ".Designer.cs").
    /// </summary>
    protected override string GetDefaultExtension()
    {
        return (".Designer" + base.GetDefaultExtension());
    }

    /// <summary>
    /// Generates entity classes for the CSDL found in <paramref name="inputFileContent"/>.
    /// </summary>
    /// <param name="inputFileContent">Text of the EDMX document.</param>
    /// <returns>
    /// The generated code encoded as UTF-8, or <c>null</c> when generation failed. Failures are reported
    /// through <c>GeneratorError</c> instead of being thrown.
    /// </returns>
    protected override byte[] GenerateCode(string inputFileContent)
    {
        try
        {
            XElement csdlContent = ExtractCsdlContent(inputFileContent);
            if (csdlContent == null)
            {
                throw new InvalidOperationException("No CSDL content in input file");
            }

            // Per-run state: clear anything left over from a previous invocation.
            _objectContext = null;
            _entitySetNames = null;
            _typesHierarchyToAddInObjectContext = null;

            LanguageOption languageOption = GetLanguageOption();

            if (base.CodeGeneratorProgress != null)
            {
                base.CodeGeneratorProgress.Progress(33, 100);
            }

            // Must happen before GenerateCode below: OnTypeGenerated reads the hierarchy when it sees the container.
            RegisterDerivedEntityTypes(csdlContent);

            IList<EdmSchemaError> errors;
            string generatedCode;
            using (StringWriter codeWriter = new StringWriter(CultureInfo.InvariantCulture))
            {
                using (XmlReader csdlReader = csdlContent.CreateReader())
                {
                    EntityClassGenerator classGenerator = new EntityClassGenerator(languageOption);
                    classGenerator.OnTypeGenerated += new TypeGeneratedEventHandler(OnTypeGenerated);
                    classGenerator.OnPropertyGenerated += new PropertyGeneratedEventHandler(OnPropertyGenerated);
                    errors = classGenerator.GenerateCode(csdlReader, codeWriter);
                }
                generatedCode = codeWriter.ToString();
            }

            if (base.CodeGeneratorProgress != null)
            {
                base.CodeGeneratorProgress.Progress(66, 100);
            }

            ReportSchemaErrors(errors);
            byte[] generatedCodeAsBytes = Encoding.UTF8.GetBytes(generatedCode);

            if (base.CodeGeneratorProgress != null)
            {
                base.CodeGeneratorProgress.Progress(100, 100);
            }
            return generatedCodeAsBytes;
        }
        catch (Exception e)
        {
            // Deliberate catch-all: every failure is surfaced as a generator error and the method returns null.
            base.GeneratorError(4, e.Message, 1, 1);
            return null;
        }
    }

    /// <summary>
    /// Chooses C# or VB output from the code provider's file extension.
    /// </summary>
    /// <exception cref="InvalidOperationException">The extension is missing or is neither ".cs" nor ".vb".</exception>
    private LanguageOption GetLanguageOption()
    {
        string fileExtension = base.GetCodeProvider().FileExtension;
        if (!String.IsNullOrEmpty(fileExtension))
        {
            // Normalise "vb" and ".vb" alike to ".vb".
            fileExtension = "." + fileExtension.TrimStart('.');
            if (fileExtension.EndsWith(".vb", StringComparison.OrdinalIgnoreCase))
            {
                return LanguageOption.GenerateVBCode;
            }
            if (fileExtension.EndsWith(".cs", StringComparison.OrdinalIgnoreCase))
            {
                return LanguageOption.GenerateCSharpCode;
            }
        }
        // A null/empty extension used to fail with a NullReferenceException; it now gets the same message as
        // any other unsupported language.
        throw new InvalidOperationException("Unsupported project language. Only C# and VB are supported.");
    }

    /// <summary>
    /// Scans the CSDL for entity types that declare a <c>BaseType</c> and registers each one under the root
    /// type of its inheritance chain (see <see cref="AddSubEntities"/>).
    /// </summary>
    private void RegisterDerivedEntityTypes(XElement csdlContent)
    {
        XName entityTypeElementName = XName.Get("EntityType", EdmNamespace);

        // (simple name, qualified BaseType) for every derived type, in document order ...
        List<KeyValuePair<string, string>> derivedTypes = new List<KeyValuePair<string, string>>();
        // ... and the same data keyed by simple name, for O(1) lookups while walking up the chain.
        Dictionary<string, string> baseTypeOf = new Dictionary<string, string>();

        foreach (XElement entityType in csdlContent.Descendants(entityTypeElementName))
        {
            XAttribute nameAttribute = entityType.Attribute("Name");
            XAttribute baseTypeAttribute = entityType.Attribute("BaseType");
            if (nameAttribute == null || baseTypeAttribute == null)
            {
                continue;
            }
            derivedTypes.Add(new KeyValuePair<string, string>(nameAttribute.Value, baseTypeAttribute.Value));
            baseTypeOf[nameAttribute.Value] = baseTypeAttribute.Value;
        }

        foreach (KeyValuePair<string, string> derivedType in derivedTypes)
        {
            // Walk up to the root of the chain so that grandchildren are registered under the same type as
            // their siblings; OnTypeGenerated looks the list up by the entity set's element-type name.
            // BaseType is namespace-qualified while Name is not, hence the GetSimpleName on each step.
            // The visited set stops the walk on a malformed, cyclic hierarchy instead of looping forever.
            string rootTypeName = GetSimpleName(derivedType.Value);
            HashSet<string> visited = new HashSet<string> { derivedType.Key };
            string nextBaseType;
            while (baseTypeOf.TryGetValue(rootTypeName, out nextBaseType) && visited.Add(rootTypeName))
            {
                rootTypeName = GetSimpleName(nextBaseType);
            }
            AddSubEntities(rootTypeName, derivedType.Key);
        }
    }

    /// <summary>
    /// Strips the namespace (or alias) qualifier from a type name: "My.Model.Person" becomes "Person".
    /// </summary>
    private static string GetSimpleName(string qualifiedTypeName)
    {
        // LastIndexOf, not IndexOf: the namespace itself may contain dots.
        return qualifiedTypeName.Substring(qualifiedTypeName.LastIndexOf('.') + 1);
    }

    /// <summary>
    /// Reports schema errors from the entity class generator as generator warnings or errors.
    /// </summary>
    private void ReportSchemaErrors(IEnumerable<EdmSchemaError> errors)
    {
        if (errors == null)
        {
            return;
        }
        foreach (EdmSchemaError error in errors)
        {
            // Shift to zero-based positions; clamp so that 0 or a negative value cannot wrap when cast to uint.
            uint line = (uint)Math.Max(0, error.Line - 1);
            uint column = (uint)Math.Max(0, error.Column - 1);
            if (error.Severity == EdmSchemaErrorSeverity.Warning)
            {
                base.GeneratorWarning(0, error.Message, line, column);
            }
            else
            {
                base.GeneratorError(4, error.Message, line, column);
            }
        }
    }

    /// <summary>
    /// Handler for <see cref="EntityClassGenerator.OnTypeGenerated"/>. It does two things:
    /// it adds annotation-derived attributes to the generated type, and, when the type is the entity
    /// container, it adds one <c>OfType</c> query property per registered derived entity type.
    /// </summary>
    private void OnTypeGenerated(object sender, TypeGeneratedEventArgs eventArgs)
    {
        eventArgs.AdditionalAttributes.AddRange(CreateCodeAttributes(eventArgs.TypeSource));

        EntityContainer objectContext = eventArgs.TypeSource as EntityContainer;
        if (objectContext == null)
        {
            return;
        }
        _objectContext = objectContext;

        MetadataProperty baseEntitySets = objectContext.MetadataProperties.FirstOrDefault(mp => mp.Name == "BaseEntitySets");
        if (baseEntitySets == null)
        {
            return;
        }

        foreach (EntitySetBase entitySet in (ReadOnlyMetadataCollection<EntitySetBase>)baseEntitySets.Value)
        {
            // Sets whose element type is not an EntityType (e.g. association sets) are skipped.
            EntityType rootEntityType = entitySet.ElementType as EntityType;
            if (rootEntityType == null)
            {
                continue;
            }

            // When several sets share an element type, only the first is used. Previously Add threw here,
            // and generating the derived properties again would have produced duplicate members.
            if (EntitySetNames.ContainsKey(rootEntityType.Name))
            {
                continue;
            }
            EntitySetNames.Add(rootEntityType.Name, entitySet.Name);

            foreach (string derivedEntityTypeName in GetSubEntitiesOf(rootEntityType))
            {
                eventArgs.AdditionalMembers.Add(CreateDerivedSetProperty(entitySet.Name, derivedEntityTypeName));
            }
        }
    }

    /// <summary>
    /// Builds a public property named <paramref name="derivedEntityTypeName"/> + "s" (naive pluralisation)
    /// whose getter returns <c>this.&lt;entitySetName&gt;.OfType&lt;Derived&gt;()</c>.
    /// </summary>
    private static CodeMemberProperty CreateDerivedSetProperty(string entitySetName, string derivedEntityTypeName)
    {
        CodeTypeReference derivedType = new CodeTypeReference(derivedEntityTypeName);

        // A real generic reference, not a C#-syntax string, so that the VB provider can render it too.
        CodeTypeReference queryableType = new CodeTypeReference("System.Linq.IQueryable", derivedType);
        queryableType.Options = CodeTypeReferenceOptions.GlobalReference;

        CodeMemberProperty property = new CodeMemberProperty
        {
            Name = derivedEntityTypeName + "s",
            Attributes = MemberAttributes.Public | MemberAttributes.Final,
            Type = queryableType
        };

        property.GetStatements.Add(
            new CodeMethodReturnStatement(
                new CodeMethodInvokeExpression(
                    new CodeMethodReferenceExpression(
                        new CodePropertyReferenceExpression(new CodeThisReferenceExpression(), entitySetName),
                        "OfType",
                        derivedType))));
        return property;
    }

    /// <summary>
    /// Handler for <see cref="EntityClassGenerator.OnPropertyGenerated"/>: adds annotation-derived attributes
    /// to the generated property.
    /// </summary>
    private void OnPropertyGenerated(object sender, PropertyGeneratedEventArgs eventArgs)
    {
        eventArgs.AdditionalAttributes.AddRange(CreateCodeAttributes(eventArgs.PropertySource));
    }

    /// <summary>
    /// Turns the attribute annotations of <paramref name="item"/> into attribute declarations.
    /// </summary>
    /// <remarks>
    /// Each annotation value is a ';'-separated list of entries of the form <c>Name</c> or <c>Name(arg)</c>.
    /// An argument that parses as a bool becomes a bool literal; any other argument becomes a string literal.
    /// Commas are not treated as argument separators.
    /// </remarks>
    /// <returns>The declarations; empty when <paramref name="item"/> is null or carries no annotations.</returns>
    private IList<CodeAttributeDeclaration> CreateCodeAttributes(MetadataItem item)
    {
        List<CodeAttributeDeclaration> codeAttributeDeclarations = new List<CodeAttributeDeclaration>();
        if (item == null)
        {
            return codeAttributeDeclarations;
        }

        foreach (MetadataProperty metadataProperty in item.MetadataProperties)
        {
            if (!metadataProperty.Name.StartsWith(AttributeAnnotationsNamespace, StringComparison.Ordinal))
            {
                continue;
            }

            // Non-string values are ignored instead of failing with an InvalidCastException.
            string metadataPropertyValue = metadataProperty.Value as string;
            if (String.IsNullOrEmpty(metadataPropertyValue))
            {
                continue;
            }

            foreach (string attribute in metadataPropertyValue.Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries))
            {
                string attributeName = attribute;
                IEnumerable<string> attributeArguments = Enumerable.Empty<string>();
                if (attribute.IndexOf('(') >= 0)
                {
                    // "Name(arg)" splits into ["Name", "arg"].
                    string[] parts = attribute.Split(new[] { '(', ')' }, StringSplitOptions.RemoveEmptyEntries);
                    if (parts.Length == 0)
                    {
                        continue;
                    }
                    attributeName = parts[0];
                    attributeArguments = parts.Skip(1);
                }

                // Tolerate "A; B" and skip blank entries, which would otherwise produce invalid code.
                attributeName = attributeName.Trim();
                if (attributeName.Length == 0)
                {
                    continue;
                }

                CodeAttributeDeclaration codeAttributeDeclaration = new CodeAttributeDeclaration(attributeName);
                foreach (string attributeArgument in attributeArguments)
                {
                    bool boolValue;
                    object argumentValue = bool.TryParse(attributeArgument, out boolValue)
                        ? (object)boolValue
                        : attributeArgument;
                    codeAttributeDeclaration.Arguments.Add(new CodeAttributeArgument(new CodePrimitiveExpression(argumentValue)));
                }
                codeAttributeDeclarations.Add(codeAttributeDeclaration);
            }
        }
        return codeAttributeDeclarations;
    }

    /// <summary>
    /// Returns the CSDL <c>Schema</c> element at <c>Edmx/Runtime/ConceptualModels/Schema</c>, or <c>null</c>
    /// when any element on that path is missing.
    /// </summary>
    /// <exception cref="System.Xml.XmlException">The input is not well-formed XML.</exception>
    private XElement ExtractCsdlContent(string inputFileContent)
    {
        XNamespace edmxns = EdmxNamespace;
        XNamespace edmns = EdmNamespace;

        XDocument edmxDoc = XDocument.Parse(inputFileContent);

        XElement edmxNode = edmxDoc.Element(edmxns + "Edmx");
        XElement runtimeNode = edmxNode == null ? null : edmxNode.Element(edmxns + "Runtime");
        XElement conceptualModelsNode = runtimeNode == null ? null : runtimeNode.Element(edmxns + "ConceptualModels");
        return conceptualModelsNode == null ? null : conceptualModelsNode.Element(edmns + "Schema");
    }
}