using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Metadata;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;

namespace AsbCloudDb
{
    /// <summary>
    /// Entity Framework Core helpers: JSON value conversion, raw-SQL bulk inserts,
    /// simple upserts and paging.
    /// </summary>
    public static class EFExtensions
    {
        private static readonly JsonSerializerOptions jsonSerializerOptions = new()
        {
            AllowTrailingCommas = true,
            WriteIndented = true,
            NumberHandling = JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.AllowNamedFloatingPointLiterals,
        };

        /// <summary>
        /// Cache of SQL factories, one per entity CLR type. The cache is static, so it is
        /// shared by every caller; a concurrent dictionary keeps simultaneous first-time
        /// lookups from corrupting it.
        /// </summary>
        private static readonly System.Collections.Concurrent.ConcurrentDictionary<Type, IQueryStringFactory> QueryFactories = new();

        /// <summary>
        /// Stores the property as a JSON string using the default serializer options of this class.
        /// </summary>
        /// <param name="builder">Builder of the property to configure.</param>
        /// <returns>The same <paramref name="builder"/>, to allow chaining.</returns>
        public static Microsoft.EntityFrameworkCore.Metadata.Builders.PropertyBuilder<TProperty> HasJsonConversion<TProperty>(
            this Microsoft.EntityFrameworkCore.Metadata.Builders.PropertyBuilder<TProperty> builder)
            => HasJsonConversion(builder, jsonSerializerOptions);

        /// <summary>
        /// Stores the property as a JSON string using the supplied serializer options and
        /// registers a value comparer for it.
        /// </summary>
        /// <param name="builder">Builder of the property to configure.</param>
        /// <param name="jsonSerializerOptions">Options used for both serialization and deserialization.</param>
        /// <returns>The same <paramref name="builder"/>, to allow chaining.</returns>
        public static Microsoft.EntityFrameworkCore.Metadata.Builders.PropertyBuilder<TProperty> HasJsonConversion<TProperty>(
            this Microsoft.EntityFrameworkCore.Metadata.Builders.PropertyBuilder<TProperty> builder,
            JsonSerializerOptions jsonSerializerOptions)
        {
            builder.HasConversion(
                s => JsonSerializer.Serialize(s, jsonSerializerOptions),
                s => JsonSerializer.Deserialize<TProperty>(s, jsonSerializerOptions)!);

            // NOTE(review): equality is decided by comparing hash codes and the snapshot
            // returns the same instance, so in-place mutation of the value is presumably
            // not seen by change tracking - confirm this is intended.
            // The lambdas are expression trees, hence '== null' rather than 'is null'.
            ValueComparer<TProperty> valueComparer = new(
                (a, b) => (a != null) && (b != null)
                    ? a.GetHashCode() == b.GetHashCode()
                    : (a == null) && (b == null),
                i => (i == null) ? -1 : i.GetHashCode(),
                i => i);

            builder.Metadata.SetValueComparer(valueComparer);
            return builder;
        }

        /// <summary>
        /// Returns the cached SQL factory for <typeparamref name="T"/>, creating it from
        /// <paramref name="dbSet"/> on first use.
        /// </summary>
        private static QueryStringFactory<T> GetQueryStringFactory<T>(DbSet<T> dbSet)
            where T : class
            => (QueryStringFactory<T>)QueryFactories.GetOrAdd(typeof(T), _ => new QueryStringFactory<T>(dbSet));

        /// <summary>
        /// Runs a generated statement, skipping the database call when there is nothing to insert.
        /// </summary>
        /// <returns>The value reported by <c>ExecuteSqlRawAsync</c>, or 0 when <paramref name="query"/> is empty.</returns>
        private static Task<int> ExecuteGeneratedSqlAsync(
            Microsoft.EntityFrameworkCore.Infrastructure.DatabaseFacade database,
            string query,
            CancellationToken token)
            => string.IsNullOrEmpty(query)
                ? Task.FromResult(0)
                : database.ExecuteSqlRawAsync(query, token);

        /// <summary>
        /// Inserts <paramref name="items"/> with a single raw SQL statement; rows that conflict
        /// on the primary key are updated with the new values.
        /// </summary>
        /// <param name="database">Database facade used to execute the statement.</param>
        /// <param name="dbSet">Set whose entity type supplies the table and column names.</param>
        /// <param name="items">Entities to write. An empty sequence executes nothing.</param>
        /// <param name="token">Cancellation token.</param>
        /// <returns>The value reported by <c>ExecuteSqlRawAsync</c>, or 0 for an empty sequence.</returns>
        public static Task<int> ExecInsertOrUpdateAsync<T>(
            this Microsoft.EntityFrameworkCore.Infrastructure.DatabaseFacade database,
            DbSet<T> dbSet,
            IEnumerable<T> items,
            CancellationToken token)
            where T : class
        {
            var query = GetQueryStringFactory(dbSet).MakeInsertOrUpdateSql(items);
            return ExecuteGeneratedSqlAsync(database, query, token);
        }

        /// <summary>
        /// Inserts <paramref name="items"/> with a single raw SQL statement; conflicting rows are skipped.
        /// </summary>
        /// <param name="database">Database facade used to execute the statement.</param>
        /// <param name="dbSet">Set whose entity type supplies the table and column names.</param>
        /// <param name="items">Entities to write. An empty sequence executes nothing.</param>
        /// <param name="token">Cancellation token.</param>
        /// <returns>The value reported by <c>ExecuteSqlRawAsync</c>, or 0 for an empty sequence.</returns>
        public static Task<int> ExecInsertOrIgnoreAsync<T>(
            this Microsoft.EntityFrameworkCore.Infrastructure.DatabaseFacade database,
            DbSet<T> dbSet,
            IEnumerable<T> items,
            CancellationToken token)
            where T : class
        {
            var query = GetQueryStringFactory(dbSet).MakeInsertOrIgnoreSql(items);
            return ExecuteGeneratedSqlAsync(database, query, token);
        }

        /// <summary>
        /// Inserts <paramref name="items"/> with a single raw SQL statement and no conflict handling.
        /// </summary>
        /// <param name="database">Database facade used to execute the statement.</param>
        /// <param name="dbSet">Set whose entity type supplies the table and column names.</param>
        /// <param name="items">Entities to write. An empty sequence executes nothing.</param>
        /// <param name="token">Cancellation token.</param>
        /// <returns>The value reported by <c>ExecuteSqlRawAsync</c>, or 0 for an empty sequence.</returns>
        public static Task<int> ExecInsertAsync<T>(
            this Microsoft.EntityFrameworkCore.Infrastructure.DatabaseFacade database,
            DbSet<T> dbSet,
            IEnumerable<T> items,
            CancellationToken token)
            where T : class
        {
            var query = GetQueryStringFactory(dbSet).MakeInsertSql(items);
            return ExecuteGeneratedSqlAsync(database, query, token);
        }

        /// <summary>Returns the table name of the entity type behind <paramref name="dbSet"/>.</summary>
        public static string GetTableName<T>(this DbSet<T> dbSet)
            where T : class
            => GetQueryStringFactory(dbSet).TableName;

        /// <summary>
        /// Returns the double-quoted column names of the entity type behind <paramref name="dbSet"/>.
        /// </summary>
        public static IEnumerable<string> GetColumnsNames<T>(this DbSet<T> dbSet)
            where T : class
            => GetQueryStringFactory(dbSet).Columns;

        /// <summary>
        /// Marks <paramref name="value"/> as updated when the set already contains it,
        /// otherwise as added.
        /// </summary>
        /// <returns>The change-tracker entry of the entity.</returns>
        public static EntityEntry<T> Upsert<T>(this DbSet<T> dbSet, T value)
            where T : class
        {
            // NOTE(review): Contains on a DbSet presumably issues a database query - confirm.
            return dbSet.Contains(value)
                ? dbSet.Update(value)
                : dbSet.Add(value);
        }

        /// <summary>
        /// Applies the <see cref="Upsert{T}"/> decision to every element of <paramref name="values"/>.
        /// </summary>
        /// <returns>How many entities were marked as updated and how many as added.</returns>
        public static (int updated, int inserted) UpsertRange<T>(this DbSet<T> dbSet, IEnumerable<T> values)
            where T : class
        {
            (int updated, int inserted) stat = (0, 0);
            foreach (var value in values)
            {
                if (dbSet.Contains(value))
                {
                    stat.updated++;
                    dbSet.Update(value);
                }
                else
                {
                    stat.inserted++;
                    dbSet.Add(value);
                }
            }

            return stat;
        }

        /// <summary>
        /// Applies optional paging: <paramref name="skip"/> and <paramref name="take"/> are
        /// used only when they are greater than zero; null or non-positive values are ignored.
        /// </summary>
        public static IQueryable<T> SkipTake<T>(this IQueryable<T> query, int? skip, int? take)
        {
            if (skip is > 0)
            {
                query = query.Skip(skip.Value);
            }

            if (take is > 0)
            {
                query = query.Take(take.Value);
            }

            return query;
        }
    }

    /// <summary>Non-generic marker that lets factories of different entity types share one cache.</summary>
    interface IQueryStringFactory { }

    /// <summary>
    /// Builds raw multi-row INSERT statements for entities of type <typeparamref name="T"/>.
    /// Values are embedded as SQL literals, with curly braces doubled so that the text can be
    /// passed to <c>ExecuteSqlRaw</c>, which treats single braces as format placeholders.
    /// </summary>
    class QueryStringFactory<T> : IQueryStringFactory
        where T : class
    {
        private readonly string insertHeader;
        private readonly string pk;
        private readonly string conflictBody;
        private readonly IReadOnlyList<IClrPropertyGetter> getters;

        /// <summary>Table name as reported by the entity type (not quoted).</summary>
        public string TableName { get; }

        /// <summary>Double-quoted column names of every property of the entity type.</summary>
        public IEnumerable<string> Columns { get; }

        public QueryStringFactory(DbSet<T> dbSet)
        {
            var entityType = dbSet.EntityType;
            var properties = entityType.GetProperties().ToList();

            TableName = entityType.GetTableName()!;
            Columns = properties.Select(p => QuoteIdentifier(p.GetColumnName())).ToList();

            // Only non-shadow properties have a CLR getter, so the inserted column list is
            // built from exactly the same properties as the value getters; otherwise the
            // column and value counts of the statement would disagree.
            var insertable = properties.Where(p => !p.IsShadowProperty()).ToList();
            getters = insertable.Select(p => p.GetGetter()).ToList();
            var insertColumns = insertable.Select(p => QuoteIdentifier(p.GetColumnName())).ToList();

            // Conflict-target columns are quoted the same way as the insert columns.
            var pkColumns = entityType.FindPrimaryKey()?.Properties
                .Select(p => QuoteIdentifier(p.GetColumnName()));
            pk = pkColumns is null
                ? string.Empty
                : $"({string.Join(", ", pkColumns)})";

            // NOTE(review): the table name is emitted unquoted, unlike the columns - confirm
            // that no mapped table needs a quoted (case-sensitive or reserved) name.
            insertHeader = $"INSERT INTO {TableName} ({string.Join(", ", insertColumns)}) VALUES ";

            var excludedUpdateSet = string.Join(", ", insertColumns.Select(n => $"{n} = excluded.{n}"));
            conflictBody = $" ON CONFLICT {pk} DO UPDATE SET {excludedUpdateSet};";
        }

        /// <summary>
        /// Builds an INSERT that skips conflicting rows.
        /// </summary>
        /// <returns>The statement, or an empty string when <paramref name="items"/> has no elements.</returns>
        public string MakeInsertOrIgnoreSql(IEnumerable<T> items)
            => BuildStatement(items, " ON CONFLICT DO NOTHING;");

        /// <summary>
        /// Builds an INSERT that overwrites rows conflicting on the primary key. When the entity
        /// type has no primary key the statement skips conflicting rows instead.
        /// </summary>
        /// <returns>The statement, or an empty string when <paramref name="items"/> has no elements.</returns>
        public string MakeInsertOrUpdateSql(IEnumerable<T> items)
            => string.IsNullOrEmpty(pk)
                ? BuildStatement(items, " ON CONFLICT DO NOTHING;")
                : BuildStatement(items, conflictBody + Environment.NewLine);

        /// <summary>
        /// Builds a plain INSERT without conflict handling.
        /// </summary>
        /// <returns>The statement, or an empty string when <paramref name="items"/> has no elements.</returns>
        public string MakeInsertSql(IEnumerable<T> items)
            => BuildStatement(items, ";");

        /// <summary>
        /// Writes the header, one value tuple per item and the given tail. Returns an empty
        /// string for an empty sequence, because "INSERT ... VALUES ;" is not valid SQL.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="items"/> is null.</exception>
        private string BuildStatement(IEnumerable<T> items, string tail)
        {
            ArgumentNullException.ThrowIfNull(items);

            var builder = new StringBuilder(insertHeader);
            var rowCount = 0;
            foreach (var item in items)
            {
                if (rowCount > 0)
                {
                    builder.Append(',');
                }

                AppendRow(builder, item);
                rowCount++;
            }

            if (rowCount == 0)
            {
                return string.Empty;
            }

            builder.Append(tail);
            return builder.ToString();
        }

        /// <summary>Appends one parenthesised, comma-separated value tuple for <paramref name="item"/>.</summary>
        private void AppendRow(StringBuilder builder, T item)
        {
            builder.Append('(');
            builder.AppendJoin(",", getters.Select(getter => FormatValue(getter.GetClrValue(item))));
            builder.Append(')');
        }

        /// <summary>Wraps a column name in double quotes.</summary>
        private static string QuoteIdentifier(string? name)
            => $"\"{name}\"";

        /// <summary>
        /// Converts a CLR value to SQL literal text: NULL, a quoted string or date, an
        /// invariant-culture number, or quoted JSON for any other type.
        /// </summary>
        private static string FormatValue(object? v)
            => v switch
            {
                null => "NULL",
                string vStr => QuoteLiteral(vStr),
                DateTime vDate => $"'{FormatDateValue(vDate)}'",
                DateTimeOffset vDate => $"'{FormatDateValue(vDate.UtcDateTime)}'",
                IFormattable vFormattable => FormatFormattableValue(vFormattable),
                _ => QuoteLiteral(JsonSerializer.Serialize(v)),
            };

        /// <summary>
        /// Produces a single-quoted SQL string literal. Embedded single quotes are doubled so
        /// the text cannot terminate the literal early.
        /// </summary>
        private static string QuoteLiteral(string text)
            => $"'{EscapeCurlyBraces(text.Replace("'", "''"))}'";

        /// <summary>Doubles curly braces so they are not taken for format placeholders.</summary>
        private static string EscapeCurlyBraces(string vStr)
            => vStr
                .Replace("{", "{{")
                .Replace("}", "}}");

        /// <summary>
        /// Formats an <see cref="IFormattable"/> with the invariant culture. Values whose text
        /// is not a valid bare SQL token (GUIDs, NaN and infinities) are single-quoted.
        /// </summary>
        private static string FormatFormattableValue(IFormattable v)
        {
            var invariant = System.Globalization.CultureInfo.InvariantCulture;
            return v switch
            {
                // NOTE(review): quoted 'NaN' / 'Infinity' assumes a PostgreSQL-style float parser - confirm.
                double vt when !double.IsFinite(vt) => $"'{vt.ToString(invariant)}'",
                float vt when !float.IsFinite(vt) => $"'{vt.ToString(invariant)}'",
                Guid vt => $"'{vt.ToString("D", invariant)}'",
                _ => v.ToString(null, invariant),
            };
        }

        /// <summary>
        /// Formats a date as UTC text with microseconds and offset. Values of unspecified kind
        /// are treated as already being UTC. The invariant culture keeps the calendar and the
        /// separators independent of the current thread culture.
        /// </summary>
        private static string FormatDateValue(DateTime vDate)
        {
            if (vDate.Kind == DateTimeKind.Unspecified)
            {
                vDate = DateTime.SpecifyKind(vDate, DateTimeKind.Utc);
            }

            return vDate.ToUniversalTime().ToString(
                "yyyy-MM-dd HH:mm:ss.ffffff zzz",
                System.Globalization.CultureInfo.InvariantCulture);
        }
    }

    /// <summary>
    /// JSON converter for <see cref="DateOnly"/>: reads a JSON date-time and keeps its date
    /// part, writes the value in the round-trip ("O") format.
    /// </summary>
    public class DateOnlyJsonConverter : JsonConverter<DateOnly>
    {
        public override DateOnly Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
        {
            return DateOnly.FromDateTime(reader.GetDateTime());
        }

        public override void Write(Utf8JsonWriter writer, DateOnly value, JsonSerializerOptions options)
        {
            var isoDate = value.ToString("O", System.Globalization.CultureInfo.InvariantCulture);
            writer.WriteStringValue(isoDate);
        }
    }
}