using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace AsbCloudDb
{
    /// <summary>
    /// Extension helpers over <see cref="DbSet{TEntity}"/> for bulk "INSERT ... ON CONFLICT" statements,
    /// table/column metadata lookup, and simple upserts.
    /// </summary>
    public static class EFExtentions
    {
        // One query-string factory per entity CLR type. A concurrent dictionary is used because these
        // helpers are reachable from async code: with a plain Dictionary, two first-time callers for the
        // same type would both miss and both call Add (throwing, or corrupting the table).
        private static readonly ConcurrentDictionary<Type, IQueryStringFactory> QueryFactories = new();

        /// <summary>Returns the cached factory for <typeparamref name="T"/>, building it from <paramref name="dbSet"/> on first use.</summary>
        private static QueryStringFactory<T> GetQueryStringFactory<T>(DbSet<T> dbSet)
            where T : class
        {
            // GetOrAdd may run the factory more than once under a race; building a factory only reads
            // model metadata, so a discarded duplicate is harmless.
            return (QueryStringFactory<T>)QueryFactories.GetOrAdd(
                typeof(T),
                static (_, dbSetArg) => new QueryStringFactory<T>(dbSetArg),
                dbSet);
        }

        /// <summary>
        /// Inserts <paramref name="items"/> into the table mapped by <paramref name="dbSet"/> with a single raw SQL
        /// statement. Rows whose primary key already exists have all their columns overwritten. If the entity
        /// has no primary key, conflicting rows are skipped.
        /// </summary>
        /// <param name="database">Database facade used to execute the raw SQL.</param>
        /// <param name="dbSet">Set whose entity metadata defines the target table and columns.</param>
        /// <param name="items">Entities to write. Enumerated exactly once.</param>
        /// <param name="token">Cancellation token forwarded to the command execution.</param>
        /// <returns>The value reported by <c>ExecuteSqlRawAsync</c>, or 0 when <paramref name="items"/> is empty.</returns>
        public static Task<int> ExecInsertOrUpdateAsync<T>(
            this Microsoft.EntityFrameworkCore.Infrastructure.DatabaseFacade database,
            DbSet<T> dbSet,
            IEnumerable<T> items,
            CancellationToken token)
            where T : class
        {
            var list = items as IReadOnlyCollection<T> ?? items.ToList();

            // "INSERT ... VALUES" with no row list is a syntax error; with nothing to write, skip the round trip.
            if (list.Count == 0)
                return Task.FromResult(0);

            var factory = GetQueryStringFactory(dbSet);
            var query = factory.MakeInsertOrUpdateSql(list);

            return database.ExecuteSqlRawAsync(query, token);
        }

        /// <summary>Returns the mapped table name of <typeparamref name="T"/> (unquoted, without schema).</summary>
        public static string GetTableName<T>(this DbSet<T> dbSet)
            where T : class
        {
            var factory = GetQueryStringFactory(dbSet);
            return factory.TableName;
        }

        /// <summary>Returns the column names of <typeparamref name="T"/>'s table, each wrapped in double quotes.</summary>
        public static IEnumerable<string> GetColumnsNames<T>(this DbSet<T> dbSet)
            where T : class
        {
            var factory = GetQueryStringFactory(dbSet);
            return factory.Columns;
        }

        /// <summary>
        /// Starts tracking <paramref name="value"/> as Modified if <c>dbSet.Contains(value)</c> is true, otherwise as Added.
        /// The existence check goes through the set's query provider.
        /// </summary>
        public static Microsoft.EntityFrameworkCore.ChangeTracking.EntityEntry<T> Upsert<T>(this DbSet<T> dbSet, T value)
            where T : class
        {
            return dbSet.Contains(value)
                ? dbSet.Update(value)
                : dbSet.Add(value);
        }

        /// <summary>
        /// Applies <see cref="Upsert{T}(DbSet{T}, T)"/> logic to each value (one existence check per value).
        /// </summary>
        /// <returns>How many values were marked for update and how many for insert.</returns>
        public static (int updated, int inserted) UpsertRange<T>(this DbSet<T> dbSet, IEnumerable<T> values)
            where T : class
        {
            (int updated, int inserted) stat = (0, 0);
            foreach (var value in values)
            {
                if (dbSet.Contains(value))
                {
                    stat.updated++;
                    dbSet.Update(value);
                }
                else
                {
                    stat.inserted++;
                    dbSet.Add(value);
                }
            }
            return stat;
        }
    }

    /// <summary>Non-generic marker so factories for different entity types can share one cache.</summary>
    interface IQueryStringFactory { }

    /// <summary>
    /// Builds "INSERT ... VALUES ... ON CONFLICT ..." statements for <typeparamref name="T"/> from its EF model metadata.
    /// The ON CONFLICT / excluded.* syntax is PostgreSQL-style; presumably the target is PostgreSQL — confirm.
    /// </summary>
    class QueryStringFactory<T> : IQueryStringFactory
        where T : class
    {
        private readonly string insertHeader;
        private readonly string pk;
        private readonly string conflictBody;
        private readonly IClrPropertyGetter[] getters;

        /// <summary>Mapped table name, unquoted and without schema.</summary>
        public string TableName { get; }

        /// <summary>Quoted column names, in the same order as the values emitted for each row.</summary>
        public IEnumerable<string> Columns { get; }

        public QueryStringFactory(DbSet<T> dbset)
        {
            var entityType = dbset.EntityType;

            // Materialized once: the getters and the column list must stay aligned, and re-running the
            // metadata query for every row is wasted work.
            var properties = entityType.GetProperties().ToArray();

            // Key columns are quoted like the data columns so mixed-case names resolve to the same identifiers.
            var pkColsNames = entityType.FindPrimaryKey()?.Properties.Select(p => QuoteIdentifier(p.GetColumnBaseName()));
            pk = pkColsNames is null ? string.Empty : $"({string.Join(", ", pkColsNames)})";

            TableName = entityType.GetTableName()!;
            var schema = entityType.GetSchema();
            var qualifiedTable = string.IsNullOrEmpty(schema)
                ? QuoteIdentifier(TableName)
                : $"{QuoteIdentifier(schema)}.{QuoteIdentifier(TableName)}";

            getters = properties.Select(p => p.GetGetter()).ToArray();
            Columns = properties.Select(p => QuoteIdentifier(p.GetColumnBaseName())).ToArray();
            var columnsString = $"({string.Join(", ", Columns)})";

            insertHeader = $"INSERT INTO {qualifiedTable} {columnsString} VALUES ";
            var excludedUpdateSet = string.Join(", ", Columns.Select(n => $"{n} = excluded.{n}"));
            conflictBody = $" ON CONFLICT {pk} DO UPDATE SET {excludedUpdateSet};";
        }

        /// <summary>
        /// Returns the full upsert statement for <paramref name="items"/>.
        /// Callers must pass at least one item: an empty row list yields invalid SQL.
        /// </summary>
        public string MakeInsertOrUpdateSql(IEnumerable<T> items)
        {
            var builder = new StringBuilder(insertHeader);
            BuildRows(builder, items);

            // Without a primary key there is no conflict target to update against, so conflicts are skipped.
            if (string.IsNullOrEmpty(pk))
                builder.Append(" ON CONFLICT DO NOTHING;");
            else
                builder.AppendLine(conflictBody);

            return builder.ToString();
        }

        /// <summary>Appends one parenthesized value tuple per item, separated by commas.</summary>
        private StringBuilder BuildRows(StringBuilder builder, IEnumerable<T> items)
        {
            var first = true;
            foreach (var item in items)
            {
                if (!first)
                    builder.Append(',');
                first = false;
                BuildRow(builder, item);
            }
            return builder;
        }

        /// <summary>Appends "(v1,v2,...)" for one entity, in <see cref="Columns"/> order.</summary>
        private StringBuilder BuildRow(StringBuilder builder, T item)
        {
            var values = getters.Select(getter => FormatValue(getter.GetClrValue(item)));
            builder.Append('(');
            builder.AppendJoin(",", values);
            builder.Append(')');
            return builder;
        }

        /// <summary>Renders a CLR value as an SQL literal.</summary>
        private static string FormatValue(object? v) => v switch
        {
            // Embedded single quotes are doubled so the value cannot terminate the literal early
            // (which would break the statement or allow SQL injection).
            // NOTE(review): assumes standard_conforming_strings = on (PostgreSQL default), i.e. backslash is not an escape.
            string vStr => $"'{vStr.Replace("'", "''")}'",
            DateTime vDate => $"'{FormatDateValue(vDate)}'",
            DateTimeOffset vDate => $"'{FormatDateValue(vDate.UtcDateTime)}'",
            // Guid is IFormattable, but its text form is not a valid bare SQL token.
            Guid vGuid => $"'{vGuid.ToString("D", CultureInfo.InvariantCulture)}'",
            // NaN/Infinity have no numeric literal form; they are passed as quoted strings instead.
            double vDouble when !double.IsFinite(vDouble) => $"'{vDouble.ToString(CultureInfo.InvariantCulture)}'",
            float vFloat when !float.IsFinite(vFloat) => $"'{vFloat.ToString(CultureInfo.InvariantCulture)}'",
            // Invariant culture keeps '.' as the decimal separator for all numeric types.
            IFormattable vFormattable => vFormattable.ToString(null, CultureInfo.InvariantCulture),
            // null serializes to "null"; bool serializes to "true"/"false".
            _ => System.Text.Json.JsonSerializer.Serialize(v),
        };

        /// <summary>Formats a timestamp as UTC text with microsecond precision and an explicit offset.</summary>
        private static string FormatDateValue(DateTime vDate)
        {
            // Values with no kind are taken to already be UTC rather than local time.
            if (vDate.Kind == DateTimeKind.Unspecified)
                vDate = DateTime.SpecifyKind(vDate, DateTimeKind.Utc);

            // Invariant culture: ':' in a custom format is the culture's time separator, and some cultures
            // use a non-Gregorian calendar, either of which would corrupt the SQL literal.
            return vDate.ToUniversalTime().ToString("yyyy-MM-dd HH:mm:ss.ffffff zzz", CultureInfo.InvariantCulture);
        }

        /// <summary>Wraps an identifier in double quotes, doubling any embedded double quote.</summary>
        private static string QuoteIdentifier(string name) => $"\"{name.Replace("\"", "\"\"")}\"";
    }
}