using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata;

namespace AsbCloudDb
{
    /// <summary>
    /// Raw-SQL helpers for EF Core <see cref="DbSet{TEntity}"/>s: bulk insert-or-update and
    /// access to the mapped table/column names.
    /// </summary>
    public static class EFExtentions
    {
        // One SQL builder per entity CLR type, shared by every caller of these static helpers.
        // ConcurrentDictionary because static helpers may be called from several threads at once;
        // concurrent writes to a plain Dictionary can corrupt it, and a racing Add throws on a duplicate key.
        // NOTE(review): the key is only typeof(T), so the model of the first DbSet seen for a type wins.
        private static readonly ConcurrentDictionary<Type, IQueryStringFactory> QueryFactories = new();

        /// <summary>
        /// Returns the cached SQL builder for <typeparamref name="T"/>, creating it from
        /// <paramref name="dbSet"/>'s model metadata on first use.
        /// </summary>
        private static QueryStringFactory<T> GetQueryStringFactory<T>(DbSet<T> dbSet)
            where T : class
            => (QueryStringFactory<T>)QueryFactories.GetOrAdd(typeof(T), _ => new QueryStringFactory<T>(dbSet));

        /// <summary>
        /// Writes <paramref name="items"/> to the table mapped by <paramref name="dbSet"/> with a single
        /// <c>INSERT ... ON CONFLICT</c> statement: rows whose primary key already exists are updated,
        /// and for keyless entities conflicting rows are skipped.
        /// </summary>
        /// <param name="database">Database facade used to execute the generated SQL.</param>
        /// <param name="dbSet">Set whose EF model metadata (table, columns, key) drives the SQL.</param>
        /// <param name="items">Entities to write; enumerated once.</param>
        /// <param name="token">Cancellation token passed to the command execution.</param>
        /// <returns>
        /// The result of <c>ExecuteSqlRawAsync</c> (number of affected rows), or 0 without touching the
        /// database when <paramref name="items"/> is empty.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="items"/> is null.</exception>
        public static Task<int> ExecInsertOrUpdateAsync<T>(this Microsoft.EntityFrameworkCore.Infrastructure.DatabaseFacade database, DbSet<T> dbSet, IEnumerable<T> items, CancellationToken token)
            where T : class
        {
            if (items is null)
                throw new ArgumentNullException(nameof(items));

            // Materialize once: we need the count, then the rows. An empty VALUES list is invalid SQL,
            // so an empty batch is a no-op instead of a database error.
            var list = items as IReadOnlyCollection<T> ?? items.ToList();
            if (list.Count == 0)
                return Task.FromResult(0);

            var factory = GetQueryStringFactory(dbSet);
            var query = factory.MakeInsertOrUpdateSql(list);

            return database.ExecuteSqlRawAsync(query, token);
        }

        /// <summary>
        /// Returns the table name the EF model maps <typeparamref name="T"/> to (unquoted, without schema).
        /// </summary>
        public static string GetTableName<T>(this DbSet<T> dbSet)
            where T : class
            => GetQueryStringFactory(dbSet).TableName;

        /// <summary>
        /// Returns the mapped column names, each wrapped in double quotes, in model property order.
        /// </summary>
        public static IEnumerable<string> GetColumnsNames<T>(this DbSet<T> dbSet)
            where T : class
            => GetQueryStringFactory(dbSet).Columns;
    }

    /// <summary>
    /// Non-generic marker implemented by <c>QueryStringFactory&lt;T&gt;</c>, so that builders for
    /// different entity types can be stored together in one collection keyed by <see cref="Type"/>.
    /// </summary>
    internal interface IQueryStringFactory
    {
    }

    /// <summary>
    /// Builds multi-row <c>INSERT ... VALUES ... ON CONFLICT ...</c> SQL for entity <typeparamref name="T"/>
    /// from EF Core model metadata. Everything derived from metadata (table, columns, key, getters) is
    /// computed once in the constructor and reused for every call.
    /// </summary>
    /// <remarks>
    /// Values are rendered as SQL literals, not bound parameters: string-like values are single-quoted
    /// with embedded quotes doubled, and identifiers are double-quoted.
    /// NOTE(review): values are read through the CLR getter, so EF value converters are not applied
    /// (e.g. enums are rendered by name) — confirm this is acceptable for every entity passed here.
    /// </remarks>
    class QueryStringFactory<T> : IQueryStringFactory
        where T : class
    {
        private const string DoNothingOnConflict = " ON CONFLICT DO NOTHING;";

        // "INSERT INTO <table> (<columns>) VALUES "; null when the entity has no table mapping.
        private readonly string insertHeader;
        // "(<quoted pk columns>)", or empty when the entity has no primary key.
        private readonly string pk;
        // " ON CONFLICT <pk> DO UPDATE SET col = excluded.col, ...;"
        private readonly string conflictBody;
        // One getter per mapped property, in the same order as Columns.
        private readonly IClrPropertyGetter[] getters;

        /// <summary>Table name as configured in the EF model (unquoted, without schema).</summary>
        public string TableName { get; }

        /// <summary>Double-quoted column names, in the order row values are emitted.</summary>
        public IEnumerable<string> Columns { get; }

        public QueryStringFactory(DbSet<T> dbset)
        {
            var entityType = dbset.EntityType;
            // Materialize once so getters and column names share one fixed order and are not
            // recomputed on every enumeration (previously the getters were rebuilt for every row).
            var properties = entityType.GetProperties().ToArray();

            TableName = entityType.GetTableName();
            getters = properties.Select(p => p.GetGetter()).ToArray();
            Columns = Array.AsReadOnly(properties
                .Select(p => QuoteIdentifier(p.GetColumnBaseName()))
                .ToArray());

            // Quote key columns exactly like Columns, so the ON CONFLICT target matches case-sensitive names.
            var pkColumns = entityType.FindPrimaryKey()?.Properties
                .Select(p => QuoteIdentifier(p.GetColumnBaseName()));
            pk = pkColumns is null ? string.Empty : $"({string.Join(", ", pkColumns)})";

            // Entities without a table mapping still expose TableName/Columns, but cannot produce an INSERT.
            if (TableName is not null)
            {
                var schema = entityType.GetSchema();
                var target = string.IsNullOrEmpty(schema)
                    ? QuoteIdentifier(TableName)
                    : $"{QuoteIdentifier(schema)}.{QuoteIdentifier(TableName)}";
                insertHeader = $"INSERT INTO {target} ({string.Join(", ", Columns)}) VALUES ";
            }

            var excludedUpdateSet = string.Join(", ", Columns.Select(c => $"{c} = excluded.{c}"));
            conflictBody = $" ON CONFLICT {pk} DO UPDATE SET {excludedUpdateSet};";
        }

        /// <summary>
        /// Builds one multi-row insert-or-update statement for <paramref name="items"/>.
        /// </summary>
        /// <param name="items">Entities to write; enumerated exactly once; must not be empty.</param>
        /// <returns>
        /// <c>INSERT ... VALUES (...),(...) ON CONFLICT (pk) DO UPDATE SET ...;</c>, or
        /// <c>... ON CONFLICT DO NOTHING;</c> when the entity has no primary key.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="items"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="items"/> is empty (a VALUES list cannot be empty).</exception>
        /// <exception cref="InvalidOperationException">The entity is not mapped to a table.</exception>
        public string MakeInsertOrUpdateSql(IEnumerable<T> items)
        {
            /* EXAMPLE:
                INSERT INTO the_table (id, column_1, column_2) 
                VALUES (1, 'A', 'X'), (2, 'B', 'Y'), (3, 'C', 'Z')
                ON CONFLICT (id) DO UPDATE 
                  SET column_1 = excluded.column_1, 
                      column_2 = excluded.column_2;
             */
            if (items is null)
                throw new ArgumentNullException(nameof(items));
            if (insertHeader is null)
                throw new InvalidOperationException($"Entity type {typeof(T).Name} is not mapped to a table.");

            var builder = new StringBuilder(insertHeader);
            if (BuildRows(builder, items) == 0)
                throw new ArgumentException("At least one item is required to build an INSERT statement.", nameof(items));

            builder.Append(string.IsNullOrEmpty(pk) ? DoNothingOnConflict : conflictBody);
            return builder.ToString();
        }

        /// <summary>Appends comma-separated row tuples for <paramref name="items"/>; returns how many were written.</summary>
        private int BuildRows(StringBuilder builder, IEnumerable<T> items)
        {
            var count = 0;
            foreach (var item in items)
            {
                if (count > 0)
                    builder.Append(',');
                BuildRow(builder, item);
                count++;
            }
            return count;
        }

        /// <summary>Appends one "(v1,v2,...)" tuple, with values in Columns order.</summary>
        private void BuildRow(StringBuilder builder, T item)
        {
            builder.Append('(');
            builder.AppendJoin(",", getters.Select(getter => FormatValue(getter.GetClrValue(item))));
            builder.Append(')');
        }

        /// <summary>Renders a CLR value as an SQL literal.</summary>
        private static string FormatValue(object v)
        => v switch
        {
            null => "null",
            string vStr => QuoteLiteral(vStr),
            DateTime vDate => QuoteLiteral(FormatDateValue(vDate)),
            DateTimeOffset vDate => QuoteLiteral(FormatDateValue(vDate.UtcDateTime)),
            // Guid and char are not valid unquoted SQL tokens.
            Guid vGuid => QuoteLiteral(vGuid.ToString()),
            char vChar => QuoteLiteral(vChar.ToString()),
            bool vBool => vBool ? "true" : "false",
            // NaN / Infinity / -Infinity are accepted by PostgreSQL only as quoted literals.
            double vDouble when !double.IsFinite(vDouble) => QuoteLiteral(vDouble.ToString(System.Globalization.CultureInfo.InvariantCulture)),
            float vFloat when !float.IsFinite(vFloat) => QuoteLiteral(vFloat.ToString(System.Globalization.CultureInfo.InvariantCulture)),
            // Numbers: invariant culture so the decimal separator is always '.'.
            IFormattable vFormattable => vFormattable.ToString(null, System.Globalization.CultureInfo.InvariantCulture),
            // NOTE(review): other types are emitted as raw JSON text, which is not a valid SQL literal
            // for objects/arrays/byte[] — add explicit handling if such columns are ever upserted.
            _ => System.Text.Json.JsonSerializer.Serialize(v),
        };

        /// <summary>Formats a timestamp as UTC; Unspecified kind is treated as already being UTC.</summary>
        private static string FormatDateValue(DateTime vDate)
        {
            if (vDate.Kind == DateTimeKind.Unspecified)
                vDate = DateTime.SpecifyKind(vDate, DateTimeKind.Utc);
            // Invariant culture: ':' in a custom format is the culture's time separator, and the culture's
            // calendar would otherwise be used. The value is always UTC here, so the offset is written
            // literally instead of "zzz" (which .NET documents as unreliable for DateTime).
            return vDate.ToUniversalTime().ToString("yyyy-MM-dd HH:mm:ss.ffffff'+00:00'", System.Globalization.CultureInfo.InvariantCulture);
        }

        /// <summary>Wraps text in single quotes, doubling embedded quotes (standard SQL string literal escaping).</summary>
        private static string QuoteLiteral(string value)
            => $"'{value.Replace("'", "''")}'";

        /// <summary>Wraps a name in double quotes, doubling embedded quotes (SQL quoted identifier).</summary>
        private static string QuoteIdentifier(string name)
            => $"\"{name.Replace("\"", "\"\"")}\"";
    }
}