forked from ddrilling/AsbCloudServer
72 lines
2.7 KiB
C#
72 lines
2.7 KiB
C#
|
using Microsoft.AspNetCore.Http;
|
|||
|
using Microsoft.Extensions.Configuration;
|
|||
|
using System.Collections.Concurrent;
|
|||
|
using System.Collections.Generic;
|
|||
|
using System.Linq;
|
|||
|
using System.Threading.Tasks;
|
|||
|
|
|||
|
namespace AsbCloudWebApi.Middlewares
|
|||
|
{
|
|||
|
#nullable enable
|
|||
|
/// <summary>
/// Limits how many requests a single user may have in flight against one controller at the same time.
/// This is not a real middleware, it is a part of PermissionsMiddlware.
/// DO NOT register it in setup.cs as middleware.
/// </summary>
class UserConnectionsLimitMiddlware
{
    /// <summary>Limit used when the configured value is missing or not positive.</summary>
    private const int DefaultParallelRequestsToController = 5;

    private readonly RequestDelegate next;
    private readonly int parallelRequestsToController;
    private readonly byte[] body;

    // idUser -> (controllerName -> number of requests currently in flight)
    private readonly ConcurrentDictionary<int, ConcurrentDictionary<string, int>> stat = new ();

    // Controller name prefixes the limit applies to; null means "apply to every controller".
    private readonly IEnumerable<string>? controllerNames;

    /// <summary>
    /// Reads the "userLimits" configuration section and prepares the 429 response body.
    /// </summary>
    /// <param name="next">Next delegate in the request pipeline.</param>
    /// <param name="configuration">
    /// Application configuration. Uses "userLimits:parallelRequestsToController" (int)
    /// and "userLimits:controllerNames" (array of controller name prefixes).
    /// </param>
    public UserConnectionsLimitMiddlware(RequestDelegate next, IConfiguration configuration)
    {
        this.next = next;

        IConfigurationSection? userLimits = configuration.GetSection("userLimits");

        var configuredLimit = userLimits?.GetValue<int>("parallelRequestsToController") ?? 0;
        parallelRequestsToController = configuredLimit > 0
            ? configuredLimit
            : DefaultParallelRequestsToController;

        // GetValue<T> converts only a scalar section value, so it can never produce a collection.
        // Get<string[]>() binds the child elements of the section instead.
        controllerNames = userLimits?.GetSection("controllerNames")?.Get<string[]>();

        // Report the effective limit, not the raw (possibly invalid) configured value.
        var bodyText = $"<html><head><title>Too Many Requests</title></head><body><h1>Too Many Requests</h1><p>I only allow {parallelRequestsToController} parallel requests per user. Try again soon.</p></body></html>";
        body = System.Text.Encoding.UTF8.GetBytes(bodyText);
    }

    /// <summary>
    /// Passes the request on if the user is below the parallel request limit for the controller,
    /// otherwise answers with 429 Too Many Requests.
    /// </summary>
    /// <param name="context">Current HTTP context.</param>
    /// <param name="idUser">Id of the user that issued the request.</param>
    /// <param name="controllerName">Name of the controller the request is routed to.</param>
    public async Task InvokeAsync(HttpContext context, int idUser, string controllerName)
    {
        if (!IsLimitedController(controllerName))
        {
            await next(context);
            return;
        }

        var userStat = stat.GetOrAdd(idUser, _ => new());
        var count = userStat.AddOrUpdate(controllerName, 1, (_, value) => value + 1);

        if (count > parallelRequestsToController)
        {
            // A rejected request must give its slot back, otherwise every 429
            // permanently shrinks the number of requests the user can make.
            Release(userStat, controllerName);
            await RejectAsync(context);
            return;
        }

        try
        {
            await next(context);
        }
        finally
        {
            Release(userStat, controllerName);
        }
    }

    /// <summary>Tells whether the limit applies to the given controller.</summary>
    private bool IsLimitedController(string controllerName)
    {
        if (controllerNames is null)
        {
            return true;
        }

        return controllerNames.Any(n => controllerName.StartsWith(n, System.StringComparison.Ordinal));
    }

    /// <summary>Atomically decrements the in-flight counter of the controller.</summary>
    private static void Release(ConcurrentDictionary<string, int> userStat, string controllerName)
    {
        userStat.AddOrUpdate(controllerName, 0, (_, value) => value - 1);
    }

    /// <summary>Writes the prepared 429 Too Many Requests response.</summary>
    private async Task RejectAsync(HttpContext context)
    {
        context.Response.Clear();
        context.Response.StatusCode = (int)System.Net.HttpStatusCode.TooManyRequests;

        // NOTE(review): Retry-After is expressed in seconds, so this asks the client
        // to wait ~16 minutes. Confirm whether milliseconds were intended.
        context.Response.Headers.RetryAfter = "1000";
        context.Response.Headers.ContentType = "text/html";
        await context.Response.BodyWriter.WriteAsync(body);
    }
}
|
|||
|
#nullable disable
|
|||
|
}
|