using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;

namespace AsbCloudInfrastructure.Services.WellOperationImport.FileParser.StringSimilarity;

/// <summary>
/// Cosine similarity between strings represented as profiles: dictionaries that map
/// every K-character substring (shingle) of the normalized text to its occurrence count.
/// </summary>
public class CosineSimilarity
{
    private const int DefaultK = 2;

    // Matches every character that is not a lowercase Latin letter, a lowercase Cyrillic
    // letter in the а-я range, or a decimal digit. Built once instead of on every call.
    // NOTE(review): 'ё' (U+0451) lies outside the а-я range and is therefore stripped too —
    // confirm whether that is intended.
    private static readonly Regex NonAlphanumeric = new Regex("[^a-zа-я0-9]", RegexOptions.Compiled);

    /// <summary>Shingle length used by <see cref="GetProfile"/>.</summary>
    protected int K { get; }

    /// <summary>Creates a calculator that uses shingles of length <paramref name="k"/>.</summary>
    /// <param name="k">Shingle length; must be positive.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="k"/> is zero or negative.</exception>
    public CosineSimilarity(int k)
    {
        if (k <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(k), "k should be positive!");
        }

        K = k;
    }

    /// <summary>Creates a calculator with the default shingle length (2).</summary>
    public CosineSimilarity() : this(DefaultK) { }

    /// <summary>
    /// Computes the cosine of the angle between two profiles:
    /// dot product divided by the product of the Euclidean norms.
    /// </summary>
    /// <param name="profile1">First shingle profile.</param>
    /// <param name="profile2">Second shingle profile.</param>
    /// <returns>
    /// The cosine similarity, or 0 when either profile has a zero norm (for example, is empty),
    /// in which case the ratio would otherwise be NaN.
    /// </returns>
    /// <exception cref="ArgumentNullException">Either profile is null.</exception>
    public double Similarity(IDictionary<string, int> profile1, IDictionary<string, int> profile2)
    {
        if (profile1 is null)
        {
            throw new ArgumentNullException(nameof(profile1));
        }

        if (profile2 is null)
        {
            throw new ArgumentNullException(nameof(profile2));
        }

        var denominator = Norm(profile1) * Norm(profile2);

        // A zero norm means there is nothing to compare; report "no similarity" instead of 0/0 = NaN.
        if (denominator == 0)
        {
            return 0;
        }

        return DotProduct(profile1, profile2) / denominator;
    }

    /// <summary>
    /// Builds the shingle profile of <paramref name="s"/>: the text is normalized by
    /// <see cref="Stemming"/> and every substring of length <see cref="K"/> is counted.
    /// </summary>
    /// <param name="s">Source text; may be null, empty or whitespace.</param>
    /// <returns>
    /// Shingle-to-count map. Empty when <paramref name="s"/> is null/blank or when the
    /// normalized text is shorter than <see cref="K"/>.
    /// </returns>
    public Dictionary<string, int> GetProfile(string s)
    {
        var shingles = new Dictionary<string, int>();

        if (string.IsNullOrWhiteSpace(s))
        {
            return shingles;
        }

        var cleanString = Stemming(s);
        var lastStart = cleanString.Length - K;

        for (var start = 0; start <= lastStart; start++)
        {
            var shingle = cleanString.Substring(start, K);
            shingles[shingle] = shingles.TryGetValue(shingle, out var count) ? count + 1 : 1;
        }

        return shingles;
    }

    /// <summary>
    /// Normalizes text for shingling: lower-cases it, splits it into whitespace-separated words,
    /// strips every character outside a-z, а-я and 0-9 from each word, drops words that are
    /// one character or shorter after stripping, and concatenates the remaining words.
    /// </summary>
    private static string Stemming(string s)
    {
        // Words must be separated BEFORE the character filter runs: the filter removes
        // whitespace as well, so filtering first would leave a single "word" and make the
        // short-word filter a no-op.
        // ToLowerInvariant keeps the result independent of the current culture
        // (under tr-TR, 'I' lower-cases to 'ı', which the filter would then delete).
        var words = s.ToLowerInvariant()
            // An empty separator array makes Split treat any white-space character as a delimiter.
            .Split(Array.Empty<char>(), StringSplitOptions.RemoveEmptyEntries)
            .Select(word => NonAlphanumeric.Replace(word, string.Empty))
            .Where(word => word.Length > 1);

        return string.Concat(words);
    }

    /// <summary>Euclidean (L2) norm of the profile's count vector.</summary>
    private static double Norm(IDictionary<string, int> profile)
    {
        double sumOfSquares = 0;

        foreach (var entry in profile)
        {
            // Multiply in double to avoid int overflow on large counts.
            sumOfSquares += 1.0 * entry.Value * entry.Value;
        }

        return Math.Sqrt(sumOfSquares);
    }

    /// <summary>Dot product of two profiles over the shingles they have in common.</summary>
    private static double DotProduct(IDictionary<string, int> profile1, IDictionary<string, int> profile2)
    {
        // Iterate over the smaller profile and probe the larger one to minimize lookups.
        var smallProfile = profile2;
        var largeProfile = profile1;

        if (profile1.Count < profile2.Count)
        {
            smallProfile = profile1;
            largeProfile = profile2;
        }

        double product = 0;

        foreach (var entry in smallProfile)
        {
            if (largeProfile.TryGetValue(entry.Key, out var otherCount))
            {
                product += 1.0 * entry.Value * otherCount;
            }
        }

        return product;
    }
}