using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace AsbCloudDb
{
    /// <summary>
    /// Extension methods that build and execute a bulk PostgreSQL
    /// <c>INSERT ... ON CONFLICT</c> (upsert) statement for an EF Core entity set.
    /// </summary>
    public static class EFExtentions
    {
        // One SQL-fragment factory per entity CLR type, built lazily on first use.
        // ConcurrentDictionary: this static cache can be reached from concurrent callers, and a
        // plain Dictionary would throw on a racing duplicate Add (or be corrupted).
        // NOTE(review): keyed only by CLR type — assumes a type maps to the same table in every model.
        private static readonly ConcurrentDictionary<Type, IQueryStringFactory> QueryFactories =
            new ConcurrentDictionary<Type, IQueryStringFactory>();

        /// <summary>Returns the cached factory for <typeparamref name="T"/>, creating it from <paramref name="dbset"/> if needed.</summary>
        private static QueryStringFactory<T> GetQueryStringFactory<T>(DbSet<T> dbset)
            where T : class
            => (QueryStringFactory<T>)QueryFactories.GetOrAdd(typeof(T), _ => new QueryStringFactory<T>(dbset));

        /// <summary>
        /// Inserts <paramref name="items"/> with a single multi-row <c>INSERT ... ON CONFLICT</c> statement.
        /// Rows whose primary key already exists have every mapped column overwritten with the new values;
        /// for an entity without a primary key, conflicting rows are skipped (<c>DO NOTHING</c>).
        /// </summary>
        /// <param name="database">Database facade used to execute the raw SQL.</param>
        /// <param name="dbset">Entity set whose model metadata (table, columns, key) drives the statement.</param>
        /// <param name="items">Entities to upsert. Values are rendered as SQL literals, not parameters.</param>
        /// <param name="token">Cancellation token passed to the database call.</param>
        /// <returns>
        /// The row count reported by <c>ExecuteSqlRawAsync</c>, or 0 without contacting the database
        /// when <paramref name="items"/> is empty.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="items"/> is null.</exception>
        public static Task<int> ExecInsertOrUpdateAsync<T>(
            this Microsoft.EntityFrameworkCore.Infrastructure.DatabaseFacade database,
            DbSet<T> dbset,
            IEnumerable<T> items,
            CancellationToken token)
            where T : class
        {
            if (items is null)
                throw new ArgumentNullException(nameof(items));

            // Materialize once: the sequence is checked for emptiness and then enumerated.
            var itemsList = items as IReadOnlyCollection<T> ?? items.ToList();
            if (itemsList.Count == 0)
                return Task.FromResult(0); // "VALUES" followed by no rows is invalid SQL.

            var factory = GetQueryStringFactory(dbset);
            var query = factory.MakeInsertOrUpdateSql(itemsList);

            // EF Core's raw-SQL command builder runs the text through string.Format (for {0}-style
            // parameter placeholders), so literal braces inside values (text, JSON) must be doubled.
            var escapedQuery = query.Replace("{", "{{").Replace("}", "}}");
            return database.ExecuteSqlRawAsync(escapedQuery, token);
        }
    }

    /// <summary>Non-generic marker so factories for different entity types can share one cache.</summary>
    interface IQueryStringFactory { }

    /// <summary>
    /// Precomputes the table, column and conflict fragments of an upsert statement for entity type
    /// <typeparamref name="T"/>, and renders entity rows as PostgreSQL literal tuples.
    /// </summary>
    class QueryStringFactory<T> : IQueryStringFactory
        where T : class
    {
        private const string DateFormat = "yyyy-MM-dd HH:mm:ss.ffffff zzz";

        private readonly string pk;
        private readonly string tableName;
        private readonly string columnsString;
        private readonly string conflictUpdateSet;
        private readonly IClrPropertyGetter[] getters;

        public QueryStringFactory(DbSet<T> dbset)
        {
            var entityType = dbset.EntityType;

            // Materialized so getters and column names are computed once and share the same order.
            var properties = entityType.GetProperties().ToArray();

            // Quoted like the data columns; otherwise mixed-case key columns would not resolve.
            var pkColsNames = entityType.FindPrimaryKey()?.Properties
                .Select(p => QuoteIdentifier(p.GetColumnBaseName()));
            pk = pkColsNames is null ? string.Empty : $"({string.Join(", ", pkColsNames)})";

            var schema = entityType.GetSchema();
            var table = QuoteIdentifier(entityType.GetTableName());
            tableName = string.IsNullOrEmpty(schema) ? table : $"{QuoteIdentifier(schema)}.{table}";

            getters = properties.Select(p => p.GetGetter()).ToArray();
            var colNames = properties.Select(p => QuoteIdentifier(p.GetColumnBaseName())).ToArray();
            columnsString = $"({string.Join(", ", colNames)})";
            conflictUpdateSet = string.Join(", ", colNames.Select(n => $"{n} = excluded.{n}"));
        }

        /// <summary>Builds the complete upsert statement for <paramref name="items"/> (expected to be non-empty).</summary>
        public string MakeInsertOrUpdateSql(IEnumerable<T> items)
        {
            /* EXAMPLE:
            INSERT INTO the_table (id, column_1, column_2)
            VALUES (1, 'A', 'X'), (2, 'B', 'Y'), (3, 'C', 'Z')
            ON CONFLICT (id) DO UPDATE
            SET column_1 = excluded.column_1,
                column_2 = excluded.column_2;
            */
            var sqlBuilder = new StringBuilder("INSERT INTO ");
            sqlBuilder.Append(tableName);
            sqlBuilder.Append(columnsString);
            sqlBuilder.AppendLine(" VALUES ");
            sqlBuilder.Append(MakeQueryValues(items));
            sqlBuilder.AppendLine(" ON CONFLICT ");
            if (string.IsNullOrEmpty(pk))
            {
                // No key to target: skip rows that would violate a unique constraint.
                sqlBuilder.Append("DO NOTHING;");
            }
            else
            {
                sqlBuilder.Append(pk);
                sqlBuilder.Append(" DO UPDATE SET ");
                sqlBuilder.AppendLine(conflictUpdateSet);
                sqlBuilder.Append(';');
            }
            return sqlBuilder.ToString();
        }

        /// <summary>Renders all items as a comma-separated list of row tuples.</summary>
        private string MakeQueryValues(IEnumerable<T> items)
            => string.Join(",", items.Select(MakeRow));

        /// <summary>Renders one entity as "(v1,v2,...)" in the same column order as the column list.</summary>
        private string MakeRow(T item)
        {
            var values = getters.Select(getter => FormatValue(getter.GetClrValue(item)));
            return $"({string.Join(",", values)})";
        }

        /// <summary>Converts a CLR value into a PostgreSQL literal.</summary>
        private static string FormatValue(object v) => v switch
        {
            null => "NULL",
            string vStr => QuoteLiteral(vStr),
            char vChar => QuoteLiteral(vChar.ToString()),
            bool vBool => vBool ? "TRUE" : "FALSE",
            DateTime vDate => QuoteLiteral(FormatDateValue(vDate)),
            DateTimeOffset vDateOffset => QuoteLiteral(
                vDateOffset.ToUniversalTime().ToString(DateFormat, CultureInfo.InvariantCulture)),
            Guid vGuid => QuoteLiteral(vGuid.ToString("D")),
            // bytea hex input format: '\x0102...'
            byte[] vBytes => QuoteLiteral("\\x" + BitConverter.ToString(vBytes).Replace("-", string.Empty)),
            // EF's default enum mapping is the underlying integer; the enum name would be parsed as an identifier.
            Enum vEnum => FormatFormattableValue((IFormattable)Convert.ChangeType(
                vEnum, Enum.GetUnderlyingType(vEnum.GetType()), CultureInfo.InvariantCulture)),
            IFormattable vFormattable => FormatFormattableValue(vFormattable),
            // Anything else (e.g. objects stored in json/jsonb columns) is sent as JSON text.
            _ => QuoteLiteral(System.Text.Json.JsonSerializer.Serialize(v)),
        };

        /// <summary>Formats numeric values culture-invariantly; non-finite floats become quoted literals.</summary>
        private static string FormatFormattableValue(IFormattable v) => v switch
        {
            // PostgreSQL accepts NaN / Infinity / -Infinity for floating-point types only as quoted literals.
            double vt when double.IsNaN(vt) || double.IsInfinity(vt) => QuoteLiteral(vt.ToString(CultureInfo.InvariantCulture)),
            float vt when float.IsNaN(vt) || float.IsInfinity(vt) => QuoteLiteral(vt.ToString(CultureInfo.InvariantCulture)),
            _ => v.ToString(null, CultureInfo.InvariantCulture),
        };

        /// <summary>Formats a timestamp in UTC with microseconds and offset; Unspecified kind is treated as UTC.</summary>
        private static string FormatDateValue(DateTime vDate)
        {
            if (vDate.Kind == DateTimeKind.Unspecified)
                vDate = DateTime.SpecifyKind(vDate, DateTimeKind.Utc);
            // InvariantCulture: ':' in a custom format string is the culture's time separator.
            return vDate.ToUniversalTime().ToString(DateFormat, CultureInfo.InvariantCulture);
        }

        /// <summary>Wraps text in single quotes, doubling embedded quotes (prevents breaking out of the literal).</summary>
        private static string QuoteLiteral(string value) => $"'{value.Replace("'", "''")}'";

        /// <summary>Wraps an identifier in double quotes, doubling embedded double quotes.</summary>
        private static string QuoteIdentifier(string name) => $"\"{name.Replace("\"", "\"\"")}\"";
    }
}