IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    T4模板配合Dapper生成Model层,包括Insert、Update、Delete、GetList方法

    summer发表于 2016-11-10 02:12:42
    love 0

    以前看到朋友开发程序时用的是一个 C# 版的 Model 生成器,里面自带 Insert、Update、Delete、GetList 这些方法,感觉挺方便。最近我利用 T4 模板和 Dapper 自己也写了一个。好东西要分享,也顺便请网友们帮忙看看还有哪些可以优化的地方。如果有什么好建议或者不懂的地方,可以联系我 QQ:1229145381(加好友请备注:CSDN)。

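    一共是三个文件:Model.tt、DbHelper.ttinclude 和 MultipleOutputHelper.ttinclude。添加的时候直接添加文本文件,然后把后缀名改成对应的 .tt 或 .ttinclude 就行。Model.tt 是主文件;两个 .ttinclude 文件的作用类似于 HTML 页面引用的 jQuery 库,由主模板通过 <#@ include #> 引用。注意:下面的 Model.tt 只引用了 MultipleOutputHelper.ttinclude;如果要使用 DbHelper.ttinclude 中的方法,需要再加一行 <#@ include file="DbHelper.ttinclude" #>。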

    下面上代码:

    Model.tt

    <#@ template debug="true" hostspecific="true" language="C#" #>
    <#@ assembly name="System.Data" #>
    <#@ assembly name="System.xml" #>
    <#@ import namespace="System.Collections.Generic" #>
    <#@ import namespace="System.Data.SqlClient" #>
    <#@ import namespace="System.Data" #>
    <#@ assembly name="System.Core" #>
    <#@ import namespace="System.Linq" #>

    <#@ include file="MultipleOutputHelper.ttinclude" #>

    <#
    // Generates one partial entity class per table (properties + Dapper CRUD helpers); each class is
    // written to <TableName>.cs by the Manager from MultipleOutputHelper.ttinclude.
    // NOTE(review): credentials are hard-coded as in the original post; keep them out of source control.
    string connectionString= "server=.;database=RoleSystem;uid=sa;password=123456;";

    // MS_Description of one column. Parameterized: table/column names are no longer spliced into the SQL text.
    string commentQuery = @"select d.value from sys.syscolumns a left outer join sys.extended_properties d on a.id = d.major_id and a.colid = d.minor_id and d.name = 'MS_Description' where a.id = object_id(@objectName) and a.name=@columnName";

    // Columns INSERT/UPDATE may assign, in column order. Identity, computed and timestamp/rowversion
    // columns are rejected by SQL Server when written, so they are skipped.
    string columnQuery = "select c.name from sys.columns c where c.object_id = object_id(@objectName) and c.is_identity = 0 and c.is_computed = 0 and type_name(c.system_type_id) <> 'timestamp' order by c.column_id";

    // Brackets an identifier so names with spaces or reserved words stay valid SQL.
    Func<string, string> quote = name => "[" + name.Replace("]", "]]") + "]";

    var manager = Manager.Create(Host, GenerationEnvironment);

    // using closes the connection even when generation throws part-way.
    using (SqlConnection conn = new SqlConnection(connectionString))
    {
    conn.Open();

    System.Data.DataTable schema = conn.GetSchema("Tables");
    foreach(System.Data.DataRow row in schema.Rows)
    {
    string tb_name= row["TABLE_NAME"].ToString();
    // Schema-qualified, quoted name. It is used for object_id() lookups and in the generated SQL,
    // so tables outside the default schema also work.
    string objectName = quote(row["TABLE_SCHEMA"].ToString()) + "." + quote(tb_name);

    // FMTONLY returns only the result-set shape; FillSchema reads column types and nullability from it.
    System.Data.DataSet ds = new DataSet();
    using (SqlCommand command = new SqlCommand("SET FMTONLY ON; select * from " + objectName + "; SET FMTONLY OFF;", conn))
    using (SqlDataAdapter ad = new SqlDataAdapter(command))
    {
    ad.FillSchema(ds, SchemaType.Mapped, tb_name);
    }

    manager.StartNewFile(tb_name+".cs");#>

    using System;
    using System.Data;
    using System.Linq;
    using Dapper;
    using System.Data.SqlClient;
    using System.Configuration;
    using System.Collections.Generic;

    namespace RoleSystem.Model
    {
    /// <summary>
    /// 实体-<#=tb_name#>
    /// </summary>
    public partial class <#=tb_name#>
    {
    <#
    PushIndent(" ");

    WriteLine("public static SqlConnection GetSqlConnection(){ return new SqlConnection(ConfigurationManager.ConnectionStrings[\"connstr\"].ConnectionString);}");

    // One auto-property per result column, documented with the column's MS_Description.
    foreach (DataColumn dc in ds.Tables[0].Columns)
    {
    object comment;
    using (SqlCommand cmd = new SqlCommand(commentQuery, conn))
    {
    cmd.Parameters.AddWithValue("@objectName", objectName);
    cmd.Parameters.AddWithValue("@columnName", dc.ColumnName);
    comment = cmd.ExecuteScalar();
    }
    // Convert.ToString maps null/DBNull to "". Line breaks are flattened so a multi-line
    // description cannot escape the /// comment.
    string summary = Convert.ToString(comment).Replace("\r", " ").Replace("\n", " ");
    // Only value types take "?": string and byte[] are already nullable, and "Byte[]?" does not compile.
    string nullableMark = dc.AllowDBNull && dc.DataType.IsValueType ? "? " : " ";
    WriteLine("\r\n /// <summary>\r\n ///" + summary + "\r\n /// </summary>\r\n public " + dc.DataType.Name + nullableMark + dc.ColumnName + " { get; set; }");
    }

    List<string> writable = new List<string>();
    using (SqlCommand cmd1 = new SqlCommand(columnQuery, conn))
    {
    cmd1.Parameters.AddWithValue("@objectName", objectName);
    using (SqlDataReader reader = cmd1.ExecuteReader())
    {
    while (reader.Read())
    writable.Add(reader.GetString(0));
    }
    }

    // string.Join puts exactly one comma between items; the old accumulation produced "a,bc".
    string columnList = string.Join(",", writable.Select(c => quote(c)));
    string valueList = string.Join(",", writable.Select(c => "@" + c));
    string setList = string.Join(",", writable.Select(c => quote(c) + "=@" + c));

    // SCOPE_IDENTITY() (unlike @@IDENTITY) ignores identity values created by triggers.
    // A table with no writable column is inserted with DEFAULT VALUES.
    string insertSql = writable.Count == 0
    ? "insert into " + objectName + " default values; select scope_identity()"
    : "insert into " + objectName + " (" + columnList + ") values (" + valueList + "); select scope_identity()";
    string updateSql = "update " + objectName + " set " + setList + " where [ID]=@id";

    // Emits the XML doc header shared by every generated method.
    Action<string> writeSummary = text =>
    {
    WriteLine(" ///<summary>");
    WriteLine(" ///" + text);
    WriteLine(" ///</summary>");
    WriteLine(" ///<returns></returns>");
    };

    // Emits "using (conn) { <statement> }" plus the method's closing brace. The old
    // try/catch(Exception){throw;} wrapper did nothing and is dropped.
    Action<string> writeConnectionBody = statement =>
    {
    WriteLine(" using (IDbConnection conn = GetSqlConnection())");
    WriteLine(" {");
    WriteLine(" " + statement);
    WriteLine(" }");
    WriteLine(" }");
    };

    //insert
    writeSummary("增加一条数据");
    WriteLine(" public object Add("+tb_name+" model)\r\n {");
    WriteLine(" string sql=\"" + insertSql + "\";");
    writeConnectionBody("return conn.ExecuteScalar(sql, model);");

    //update
    writeSummary("更新一条数据");
    WriteLine(" public int Update("+tb_name+" model)\r\n {");
    WriteLine(" string sql=\"" + updateSql + "\";");
    writeConnectionBody("return conn.Execute(sql, model);");

    //delete: the id is now actually passed as @id
    writeSummary("删除一条数据");
    WriteLine(" public int Delete(int id)\r\n {");
    WriteLine(" string sql=\"delete from " + objectName + " where [ID]=@id\";");
    writeConnectionBody("return conn.Execute(sql, new { id = id });");

    //get one entity
    writeSummary("获取一个实体");
    WriteLine(" public "+tb_name+" GetModel(int id)\r\n {");
    WriteLine(" string sql=\"select * from " + objectName + " where [ID]=@id\";");
    writeConnectionBody("return conn.Query<" + tb_name + ">(sql, new { id = id }).SingleOrDefault();");

    //get entity list: top/where/order were previously ignored
    writeSummary("获取实体集合");
    WriteLine(" public List<"+tb_name+"> GetList(int top ,string where =null,string order=null)\r\n {");
    WriteLine(" // top <= 0 returns all rows; where/order are raw SQL fragments - never pass untrusted input.");
    WriteLine(" string sql=\"select \" + (top > 0 ? \"top (\" + top + \") \" : \"\") + \"* from " + objectName + "\";");
    WriteLine(" if (!string.IsNullOrEmpty(where)) sql += \" where \" + where;");
    WriteLine(" if (!string.IsNullOrEmpty(order)) sql += \" order by \" + order;");
    writeConnectionBody("return conn.Query<" + tb_name + ">(sql).ToList();");

    //check whether a column value exists. key is a column name (it cannot be a parameter), so it is
    //bracket-quoted; val is passed as a parameter instead of being spliced into the SQL.
    writeSummary("检测字段值是否存在");
    WriteLine(" public bool Exist(string key,string val)\r\n {");
    WriteLine(" string sql=\"select count(*) from " + objectName + " where [\" + key.Replace(\"]\", \"]]\") + \"]=@val\";");
    writeConnectionBody("return conn.ExecuteScalar<int>(sql, new { val = val }) > 0;");

    PopIndent();
    #>
    }
    }

    <#
    manager.EndBlock();
    }
    }
    manager.Process(true);
    #>
    DbHelper.ttinclude :

    <#+
    public class DbHelper
    {
    #region GetDbTables

    /// <summary>
    /// Lists the user tables (type 'U') of a database with their schema, row count and primary-key flag.
    /// </summary>
    /// <param name="connectionString">Connection string used to query the server.</param>
    /// <param name="database">Database to inspect; a plain name or one already wrapped in [].</param>
    /// <param name="tables">Optional comma-separated table names to restrict the result to; entries are trimmed and blank entries ignored.</param>
    /// <returns>One <see cref="DbTable"/> per table, ordered by table name.</returns>
    public static List<DbTable> GetDbTables(string connectionString, string database, string tables = null)
    {
    string db = QuoteName(database);
    List<SqlParameter> parms = new List<SqlParameter>();
    string tableFilter = string.Empty;
    if (!string.IsNullOrEmpty(tables))
    {
    // Each name becomes its own parameter instead of being spliced into the SQL text.
    List<string> placeholders = new List<string>();
    foreach (string part in tables.Split(','))
    {
    string name = part.Trim();
    if (name.Length == 0)
    continue;
    string placeholder = "@t" + placeholders.Count;
    placeholders.Add(placeholder);
    parms.Add(new SqlParameter(placeholder, SqlDbType.NVarChar, 128) { Value = name });
    }
    if (placeholders.Count > 0)
    tableFilter = " and obj.name in (" + string.Join(",", placeholders) + ")";
    }
    #region SQL
    // The primary-key subquery is database-qualified too. Unqualified, it read sys.indexes of the
    // connection's current database instead of the target database.
    string sql = string.Format(@"SELECT
    obj.name tablename,
    schem.name schemname,
    idx.rows,
    CAST
    (
    CASE
    WHEN (SELECT COUNT(1) FROM {0}.sys.indexes WHERE object_id= obj.OBJECT_ID AND is_primary_key=1) >=1 THEN 1
    ELSE 0
    END
    AS BIT) HasPrimaryKey
    from {0}.sys.objects obj
    inner join {0}.dbo.sysindexes idx on obj.object_id=idx.id and idx.indid<=1
    INNER JOIN {0}.sys.schemas schem ON obj.schema_id=schem.schema_id
    where type='U' {1}
    order by obj.name", db, tableFilter);
    #endregion
    DataTable dt = GetDataTable(connectionString, sql, parms.ToArray());
    return dt.Rows.Cast<DataRow>().Select(row => new DbTable
    {
    TableName = row.Field<string>("tablename"),
    SchemaName = row.Field<string>("schemname"),
    Rows = row.Field<int>("rows"),
    HasPrimaryKey = row.Field<bool>("HasPrimaryKey")
    }).ToList();
    }
    #endregion

    #region GetDbColumns

    /// <summary>
    /// Lists the columns of one table: key/identity/nullability flags, lengths, precision data and the MS_Description remark.
    /// </summary>
    /// <param name="connectionString">Connection string used to query the server.</param>
    /// <param name="database">Database containing the table; a plain name or one already wrapped in [].</param>
    /// <param name="tableName">Table name, without schema.</param>
    /// <param name="schema">Schema owning the table; defaults to dbo.</param>
    /// <returns>One <see cref="DbColumn"/> per column, ordered by column_id.</returns>
    public static List<DbColumn> GetDbColumns(string connectionString, string database, string tableName, string schema = "dbo")
    {
    string db = QuoteName(database);
    #region SQL
    // extended_properties is restricted to class 1 (object/column) and MS_Description. Without that,
    // any other extended property on a column produced a duplicate column row.
    string sql = string.Format(@"
    WITH indexCTE AS
    (
    SELECT
    ic.column_id,
    ic.index_column_id,
    ic.object_id
    FROM {0}.sys.indexes idx
    INNER JOIN {0}.sys.index_columns ic ON idx.index_id = ic.index_id AND idx.object_id = ic.object_id
    WHERE idx.object_id =OBJECT_ID(@tableName) AND idx.is_primary_key=1
    )
    select
    colm.column_id ColumnID,
    CAST(CASE WHEN indexCTE.column_id IS NULL THEN 0 ELSE 1 END AS BIT) IsPrimaryKey,
    colm.name ColumnName,
    systype.name ColumnType,
    colm.is_identity IsIdentity,
    colm.is_nullable IsNullable,
    cast(colm.max_length as int) ByteLength,
    (
    case
    when systype.name='nvarchar' and colm.max_length>0 then colm.max_length/2
    when systype.name='nchar' and colm.max_length>0 then colm.max_length/2
    when systype.name='ntext' and colm.max_length>0 then colm.max_length/2
    else colm.max_length
    end
    ) CharLength,
    cast(colm.precision as int) Precision,
    cast(colm.scale as int) Scale,
    prop.value Remark
    from {0}.sys.columns colm
    inner join {0}.sys.types systype on colm.system_type_id=systype.system_type_id and colm.user_type_id=systype.user_type_id
    left join {0}.sys.extended_properties prop on prop.class=1 and colm.object_id=prop.major_id and colm.column_id=prop.minor_id and prop.name='MS_Description'
    LEFT JOIN indexCTE ON colm.column_id=indexCTE.column_id AND colm.object_id=indexCTE.object_id
    where colm.object_id=OBJECT_ID(@tableName)
    order by colm.column_id", db);
    #endregion
    // OBJECT_ID accepts up to nvarchar(776); the previous size of 100 could truncate a quoted three-part name.
    SqlParameter param = new SqlParameter("@tableName", SqlDbType.NVarChar, 776) { Value = db + "." + QuoteName(schema) + "." + QuoteName(tableName) };
    DataTable dt = GetDataTable(connectionString, sql, param);

    return dt.Rows.Cast<DataRow>().Select(row => new DbColumn()
    {
    ColumnID = row.Field<int>("ColumnID"),
    IsPrimaryKey = row.Field<bool>("IsPrimaryKey"),
    ColumnName = row.Field<string>("ColumnName"),
    ColumnType = row.Field<string>("ColumnType"),
    IsIdentity = row.Field<bool>("IsIdentity"),
    IsNullable = row.Field<bool>("IsNullable"),
    ByteLength = row.Field<int>("ByteLength"),
    CharLength = row.Field<int>("CharLength"),
    Scale = row.Field<int>("Scale"),
    Remark = row["Remark"].ToString()
    }).ToList();
    }

    #endregion

    #region GetDataTable

    /// <summary>
    /// Runs a query on a new connection and returns its first result set as a DataTable.
    /// </summary>
    /// <param name="connectionString">Connection string for the new connection.</param>
    /// <param name="commandText">SQL text to execute.</param>
    /// <param name="parms">Parameters referenced by <paramref name="commandText"/>; may be empty or null.</param>
    public static DataTable GetDataTable(string connectionString, string commandText, params SqlParameter[] parms)
    {
    using (SqlConnection connection = new SqlConnection(connectionString))
    using (SqlCommand command = connection.CreateCommand())
    using (SqlDataAdapter adapter = new SqlDataAdapter(command))
    {
    command.CommandText = commandText;
    if (parms != null)
    command.Parameters.AddRange(parms);
    DataTable dt = new DataTable();
    // The connection is never opened explicitly: DbDataAdapter.Fill opens and closes it as needed.
    adapter.Fill(dt);

    return dt;
    }
    }

    #endregion

    /// <summary>
    /// Brackets an identifier for use in SQL text, escaping "]". A name already wrapped in [] is returned unchanged.
    /// </summary>
    private static string QuoteName(string name)
    {
    if (name.StartsWith("[", StringComparison.Ordinal) && name.EndsWith("]", StringComparison.Ordinal))
    return name;
    return "[" + name.Replace("]", "]]") + "]";
    }
    }

    #region DbTable
    /// <summary>
    /// Metadata for one database table, as produced by DbHelper.GetDbTables.
    /// </summary>
    public sealed class DbTable
    {
    /// <summary>Name of the table.</summary>
    public string TableName { get; set; }

    /// <summary>Name of the schema that owns the table.</summary>
    public string SchemaName { get; set; }

    /// <summary>Number of rows recorded for the table.</summary>
    public int Rows { get; set; }

    /// <summary>True when the table has a primary key.</summary>
    public bool HasPrimaryKey { get; set; }
    }
    #endregion

    #region DbColumn
    /// <summary>
    /// Metadata for one table column, as produced by DbHelper.GetDbColumns.
    /// </summary>
    public sealed class DbColumn
    {
    /// <summary>Column id (position) of the column within its table.</summary>
    public int ColumnID { get; set; }

    /// <summary>True when the column is part of the primary key.</summary>
    public bool IsPrimaryKey { get; set; }

    /// <summary>Name of the column.</summary>
    public string ColumnName { get; set; }

    /// <summary>SQL Server type name of the column.</summary>
    public string ColumnType { get; set; }

    /// <summary>C# type name for <see cref="ColumnType"/>; looked up through SqlServerDbTypeMap on every read.</summary>
    public string CSharpType
    {
    get { return SqlServerDbTypeMap.MapCsharpType(ColumnType); }
    }

    /// <summary>CLR type for <see cref="ColumnType"/>; looked up through SqlServerDbTypeMap on every read.</summary>
    public Type CommonType
    {
    get { return SqlServerDbTypeMap.MapCommonType(ColumnType); }
    }

    /// <summary>Length of the column in bytes.</summary>
    public int ByteLength { get; set; }

    /// <summary>Length of the column in characters.</summary>
    public int CharLength { get; set; }

    /// <summary>Number of digits after the decimal point.</summary>
    public int Scale { get; set; }

    /// <summary>True when the column is an identity (auto-increment) column.</summary>
    public bool IsIdentity { get; set; }

    /// <summary>True when the column accepts NULL.</summary>
    public bool IsNullable { get; set; }

    /// <summary>Description (remark) of the column.</summary>
    public string Remark { get; set; }
    }
    #endregion

    #region SqlServerDbTypeMap

    /// <summary>
    /// Maps SQL Server type names (as stored in sys.types.name) to C# type names and CLR types.
    /// Unknown types map to object.
    /// </summary>
    public class SqlServerDbTypeMap
    {
    // A single table drives both lookups, so MapCsharpType and MapCommonType cannot drift apart.
    // OrdinalIgnoreCase replaces dbtype.ToLower(), which depends on the current culture
    // (e.g. under tr-TR "INT".ToLower() is not "int").
    private static readonly Dictionary<string, Tuple<string, Type>> Map =
    new Dictionary<string, Tuple<string, Type>>(StringComparer.OrdinalIgnoreCase)
    {
    { "bigint", Tuple.Create("long", typeof(long)) },
    { "binary", Tuple.Create("byte[]", typeof(byte[])) },
    { "bit", Tuple.Create("bool", typeof(bool)) },
    { "char", Tuple.Create("string", typeof(string)) },
    { "date", Tuple.Create("DateTime", typeof(DateTime)) },
    { "datetime", Tuple.Create("DateTime", typeof(DateTime)) },
    { "datetime2", Tuple.Create("DateTime", typeof(DateTime)) },
    { "datetimeoffset", Tuple.Create("DateTimeOffset", typeof(DateTimeOffset)) },
    { "decimal", Tuple.Create("decimal", typeof(decimal)) },
    { "float", Tuple.Create("double", typeof(double)) },
    { "image", Tuple.Create("byte[]", typeof(byte[])) },
    { "int", Tuple.Create("int", typeof(int)) },
    { "money", Tuple.Create("decimal", typeof(decimal)) },
    { "nchar", Tuple.Create("string", typeof(string)) },
    { "ntext", Tuple.Create("string", typeof(string)) },
    { "numeric", Tuple.Create("decimal", typeof(decimal)) },
    { "nvarchar", Tuple.Create("string", typeof(string)) },
    { "real", Tuple.Create("Single", typeof(Single)) },
    { "smalldatetime", Tuple.Create("DateTime", typeof(DateTime)) },
    { "smallint", Tuple.Create("short", typeof(short)) },
    { "smallmoney", Tuple.Create("decimal", typeof(decimal)) },
    { "sql_variant", Tuple.Create("object", typeof(object)) },
    // sysname is SQL Server's built-in alias for nvarchar(128); it was previously mapped to object.
    { "sysname", Tuple.Create("string", typeof(string)) },
    { "text", Tuple.Create("string", typeof(string)) },
    { "time", Tuple.Create("TimeSpan", typeof(TimeSpan)) },
    { "timestamp", Tuple.Create("byte[]", typeof(byte[])) },
    { "tinyint", Tuple.Create("byte", typeof(byte)) },
    { "uniqueidentifier", Tuple.Create("Guid", typeof(Guid)) },
    { "varbinary", Tuple.Create("byte[]", typeof(byte[])) },
    { "varchar", Tuple.Create("string", typeof(string)) },
    { "xml", Tuple.Create("string", typeof(string)) }
    };

    /// <summary>
    /// Returns the C# type name for a SQL Server type name (case-insensitive).
    /// </summary>
    /// <param name="dbtype">SQL Server type name.</param>
    /// <returns><paramref name="dbtype"/> itself when it is null or empty, "object" for unknown types.</returns>
    public static string MapCsharpType(string dbtype)
    {
    if (string.IsNullOrEmpty(dbtype)) return dbtype;
    Tuple<string, Type> entry;
    return Map.TryGetValue(dbtype, out entry) ? entry.Item1 : "object";
    }

    /// <summary>
    /// Returns the CLR type for a SQL Server type name (case-insensitive).
    /// </summary>
    /// <param name="dbtype">SQL Server type name.</param>
    /// <returns>The type of Type.Missing when <paramref name="dbtype"/> is null or empty, typeof(object) for unknown types.</returns>
    public static Type MapCommonType(string dbtype)
    {
    if (string.IsNullOrEmpty(dbtype)) return Type.Missing.GetType();
    Tuple<string, Type> entry;
    return Map.TryGetValue(dbtype, out entry) ? entry.Item2 : typeof(object);
    }
    }
    #endregion

    #>
    MultipleOutputHelper.ttinclude:

    <#@ assembly name="System.Core"
    #><#@ assembly name="System.Data.Linq"
    #><#@ assembly name="EnvDTE"
    #><#@ assembly name="System.Xml"
    #><#@ assembly name="System.Xml.Linq"
    #><#@ import namespace="System"
    #><#@ import namespace="System.CodeDom"
    #><#@ import namespace="System.CodeDom.Compiler"
    #><#@ import namespace="System.Collections.Generic"
    #><#@ import namespace="System.Data.Linq"
    #><#@ import namespace="System.Data.Linq.Mapping"
    #><#@ import namespace="System.IO"
    #><#@ import namespace="System.Linq"
    #><#@ import namespace="System.Reflection"
    #><#@ import namespace="System.Text"
    #><#@ import namespace="System.Xml.Linq"
    #><#@ import namespace="Microsoft.VisualStudio.TextTemplating"
    #><#+

    // Manager records the blocks written to the template output (GenerationEnvironment) so that
    // Process(true) can split them into separate files. Usage: StartNewFile(name) before each file's
    // text, EndBlock() after it, and Process(true) once at the end of the template. Optional
    // header/footer blocks are prepended/appended to every generated file.
    class Manager {
    // A slice of the template output: Length characters starting at offset Start.
    private class Block {
    public String Name;
    public int Start, Length;
    public bool IncludeInDefault;
    }

    private Block currentBlock;
    private List<Block> files = new List<Block>();
    private Block footer = new Block();
    private Block header = new Block();
    private ITextTemplatingEngineHost host;
    private StringBuilder template;
    protected List<String> generatedFileNames = new List<String>();

    // Chooses the Visual Studio-aware manager when the host exposes IServiceProvider; that manager also
    // syncs project items and checks files out of source control. Otherwise a plain file writer is used.
    public static Manager Create(ITextTemplatingEngineHost host, StringBuilder template) {
    return (host is IServiceProvider) ? new VSManager(host, template) : new Manager(host, template);
    }

    // Ends the current block (if any) and starts collecting output for the file `name`,
    // which Process writes into the template's directory.
    public void StartNewFile(String name) {
    if (name == null)
    throw new ArgumentNullException("name");
    CurrentBlock = new Block { Name = name };
    }

    // Starts the footer appended to every generated file. When includeInDefault is false, the
    // footer is also removed from the template's own output.
    public void StartFooter(bool includeInDefault = true) {
    CurrentBlock = footer;
    footer.IncludeInDefault = includeInDefault;
    }

    // Starts the header prepended to every generated file. When includeInDefault is false, the
    // header is also removed from the template's own output.
    public void StartHeader(bool includeInDefault = true) {
    CurrentBlock = header;
    header.IncludeInDefault = includeInDefault;
    }

    // Closes the current block by recording its length. File blocks (not header/footer) are
    // queued for Process. Does nothing when no block is open.
    public void EndBlock() {
    if (CurrentBlock == null)
    return;
    CurrentBlock.Length = template.Length - CurrentBlock.Start;
    if (CurrentBlock != header && CurrentBlock != footer)
    files.Add(CurrentBlock);
    currentBlock = null;
    }

    // When split is true, writes each queued file (header + block text + footer) and removes its text
    // from the template output. Blocks are processed last-to-first so that removing one does not shift
    // the Start offsets of the blocks before it. `sync` is only used by VSManager.
    public virtual void Process(bool split, bool sync = true) {
    if (split) {
    EndBlock();
    String headerText = template.ToString(header.Start, header.Length);
    String footerText = template.ToString(footer.Start, footer.Length);
    String outputPath = Path.GetDirectoryName(host.TemplateFile);
    files.Reverse();
    if (!footer.IncludeInDefault)
    template.Remove(footer.Start, footer.Length);
    foreach(Block block in files) {
    String fileName = Path.Combine(outputPath, block.Name);
    String content = headerText + template.ToString(block.Start, block.Length) + footerText;
    generatedFileNames.Add(fileName);
    CreateFile(fileName, content);
    template.Remove(block.Start, block.Length);
    }
    if (!header.IncludeInDefault)
    template.Remove(header.Start, header.Length);
    }
    }

    // Writes the file only when its content changed, avoiding needless rewrites.
    protected virtual void CreateFile(String fileName, String content) {
    if (IsFileContentDifferent(fileName, content))
    File.WriteAllText(fileName, content);
    }

    // The plain manager has no project information, so the namespace hooks return null.
    public virtual String GetCustomToolNamespace(String fileName) {
    return null;
    }

    public virtual String DefaultProjectNamespace {
    get { return null; }
    }

    // True when the file does not exist or its current text differs from newContent.
    protected bool IsFileContentDifferent(String fileName, String newContent) {
    return !(File.Exists(fileName) && File.ReadAllText(fileName) == newContent);
    }

    private Manager(ITextTemplatingEngineHost host, StringBuilder template) {
    this.host = host;
    this.template = template;
    }

    // Setting the current block closes any open block first and records where the new one starts.
    private Block CurrentBlock {
    get { return currentBlock; }
    set {
    if (CurrentBlock != null)
    EndBlock();
    if (value != null)
    value.Start = template.Length;
    currentBlock = value;
    }
    }

    // Visual Studio variant: also checks generated files out of source control and keeps the
    // project items nested under the template in sync with the generated files.
    private class VSManager: Manager {
    private EnvDTE.ProjectItem templateProjectItem;
    private EnvDTE.DTE dte;
    private Action<String> checkOutAction;
    private Action<IEnumerable<String>> projectSyncAction;

    // The containing project's DefaultNamespace property.
    public override String DefaultProjectNamespace {
    get {
    return templateProjectItem.ContainingProject.Properties.Item("DefaultNamespace").Value.ToString();
    }
    }

    // The CustomToolNamespace property of the project item for fileName.
    public override String GetCustomToolNamespace(string fileName) {
    return dte.Solution.FindProjectItem(fileName).Properties.Item("CustomToolNamespace").Value.ToString();
    }

    // Does nothing when the template item has no ProjectItems collection. Otherwise it splits the output
    // and, when sync is true, updates the project items to match the generated files.
    public override void Process(bool split, bool sync) {
    if (templateProjectItem.ProjectItems == null)
    return;
    base.Process(split, sync);
    if (sync)
    projectSyncAction.EndInvoke(projectSyncAction.BeginInvoke(generatedFileNames, null, null));
    }

    // Like the base implementation, but checks the file out of source control before writing it.
    protected override void CreateFile(String fileName, String content) {
    if (IsFileContentDifferent(fileName, content)) {
    CheckoutFileIfRequired(fileName);
    File.WriteAllText(fileName, content);
    }
    }

    internal VSManager(ITextTemplatingEngineHost host, StringBuilder template)
    : base(host, template) {
    // "as" keeps the null check reachable; a direct cast would throw InvalidCastException instead.
    var hostServiceProvider = host as IServiceProvider;
    if (hostServiceProvider == null)
    throw new ArgumentNullException("host", "Could not obtain IServiceProvider");
    dte = (EnvDTE.DTE) hostServiceProvider.GetService(typeof(EnvDTE.DTE));
    if (dte == null)
    throw new ArgumentNullException("host", "Could not obtain DTE from host");
    templateProjectItem = dte.Solution.FindProjectItem(host.TemplateFile);
    checkOutAction = (String fileName) => dte.SourceControl.CheckOutItem(fileName);
    projectSyncAction = (IEnumerable<String> keepFileNames) => ProjectSync(templateProjectItem, keepFileNames);
    }

    // Makes the items nested under the template match this run's output. It removes items that were not
    // generated (except the template's own default output, whose name starts with the template's
    // file name) and adds generated files that are not yet in the project.
    private static void ProjectSync(EnvDTE.ProjectItem templateProjectItem, IEnumerable<String> keepFileNames) {
    // Paths are compared case-insensitively, as on the Windows file system.
    var keepFileNameSet = new HashSet<String>(keepFileNames, StringComparer.OrdinalIgnoreCase);
    var projectFiles = new Dictionary<String, EnvDTE.ProjectItem>(StringComparer.OrdinalIgnoreCase);
    var originalFilePrefix = Path.GetFileNameWithoutExtension(templateProjectItem.get_FileNames(0)) + ".";
    foreach(EnvDTE.ProjectItem projectItem in templateProjectItem.ProjectItems)
    projectFiles.Add(projectItem.get_FileNames(0), projectItem);

    // Remove unused items from the project (O(1) set lookup instead of scanning keepFileNames).
    foreach(var pair in projectFiles)
    if (!keepFileNameSet.Contains(pair.Key) && !(Path.GetFileNameWithoutExtension(pair.Key) + ".").StartsWith(originalFilePrefix, StringComparison.OrdinalIgnoreCase))
    pair.Value.Delete();

    // Add missing files to the project
    foreach(String fileName in keepFileNameSet)
    if (!projectFiles.ContainsKey(fileName))
    templateProjectItem.ProjectItems.AddFromFile(fileName);
    }

    // Checks the file out only when it is under source control and not already checked out.
    private void CheckoutFileIfRequired(String fileName) {
    var sc = dte.SourceControl;
    if (sc != null && sc.IsItemUnderSCC(fileName) && !sc.IsItemCheckedOut(fileName))
    checkOutAction.EndInvoke(checkOutAction.BeginInvoke(fileName, null, null));
    }
    }
    } #>



沪ICP备19023445号-2号
友情链接