forked from ddrilling/AsbCloudServer
72 lines
2.7 KiB
C#
72 lines
2.7 KiB
C#
using Microsoft.AspNetCore.Http;
|
|
using Microsoft.Extensions.Configuration;
|
|
using System.Collections.Concurrent;
|
|
using System.Collections.Generic;
|
|
using System.Linq;
|
|
using System.Threading.Tasks;
|
|
|
|
namespace AsbCloudWebApi.Middlewares;
|
|
|
|
|
|
/// <summary>
/// Limits the number of parallel (in-flight) requests a single user may make to a single controller.
/// <para>
/// This is not a real middleware; it is part of PermissionsMiddlware.
/// DO NOT register it in setup.cs as middleware.
/// </para>
/// </summary>
/// <remarks>
/// Configuration is read from the <c>userLimits</c> section:
/// <list type="bullet">
/// <item><c>parallelRequestsToController</c> — maximum number of concurrent requests per user and controller;
/// a missing or non-positive value falls back to 8.</item>
/// <item><c>controllerNames</c> — optional list of controller-name prefixes the limit applies to;
/// when it is not configured (or empty), every controller is limited.</item>
/// </list>
/// Requests over the limit are answered with <c>429 Too Many Requests</c> and a small HTML body.
/// </remarks>
class UserConnectionsLimitMiddlware
{
    private readonly int parallelRequestsToController;

    private readonly RequestDelegate next;

    // Pre-encoded 429 response body (UTF-8), built once in the constructor.
    private readonly byte[] responseBody;

    // idUser -> controllerName -> number of requests currently being processed.
    private readonly ConcurrentDictionary<int, ConcurrentDictionary<string, int>> stat = new();

    // Controller-name prefixes subject to the limit; null means "limit every controller".
    private readonly IEnumerable<string>? controllerNames;

    /// <summary>
    /// Reads the limits from configuration and prepares the 429 response body.
    /// </summary>
    /// <param name="next">The next delegate in the request pipeline.</param>
    /// <param name="configuration">Application configuration containing the optional <c>userLimits</c> section.</param>
    public UserConnectionsLimitMiddlware(RequestDelegate next, IConfiguration configuration)
    {
        const int parallelRequestsToControllerDefault = 8;

        this.next = next;

        var userLimits = configuration.GetSection("userLimits");

        var parallelRequestsToController = userLimits.GetValue("parallelRequestsToController", parallelRequestsToControllerDefault);
        this.parallelRequestsToController = parallelRequestsToController > 0
            ? parallelRequestsToController
            : parallelRequestsToControllerDefault;

        // GetValue<T> only converts scalar values and therefore always returned null for an array;
        // configuration arrays must be bound with Get<T>().
        var names = userLimits.GetSection("controllerNames").Get<string[]>();
        controllerNames = names is { Length: > 0 } ? names : null;

        var bodyText = $"<html><head><title>Too Many Requests</title></head><body><h1>Too Many Requests</h1><p>I only allow {this.parallelRequestsToController} parallel requests per user. Try again soon.</p></body></html>";
        responseBody = System.Text.Encoding.UTF8.GetBytes(bodyText);
    }

    /// <summary>
    /// Runs <c>next</c> if the user still has a free slot for <paramref name="controllerName"/>,
    /// otherwise writes a 429 response without calling <c>next</c>.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="idUser">Identifier of the user making the request.</param>
    /// <param name="controllerName">Name of the controller being called.</param>
    public async Task InvokeAsync(HttpContext context, int idUser, string controllerName)
    {
        // Controllers not covered by the configured prefixes are passed through unlimited.
        if (controllerNames is not null
            && !controllerNames.Any(n => controllerName.StartsWith(n, System.StringComparison.Ordinal)))
        {
            await next(context);
            return;
        }

        var userStat = stat.GetOrAdd(idUser, _ => new ConcurrentDictionary<string, int>());

        if (!TryAcquireSlot(userStat, controllerName))
        {
            await WriteTooManyRequestsAsync(context);
            return;
        }

        try
        {
            await next(context);
        }
        finally
        {
            // Atomic decrement: releases the slot taken by TryAcquireSlot.
            userStat.AddOrUpdate(controllerName, 0, (_, count) => count - 1);
        }
    }

    /// <summary>
    /// Atomically increments the in-flight counter for <paramref name="controllerName"/>
    /// if it is below the configured limit.
    /// </summary>
    /// <returns><c>true</c> if a slot was taken (and must later be released); otherwise <c>false</c>.</returns>
    private bool TryAcquireSlot(ConcurrentDictionary<string, int> userStat, string controllerName)
    {
        // Compare-and-swap loop: retry when another request changed the counter concurrently,
        // so the limit is never exceeded and no increment is lost.
        while (true)
        {
            var count = userStat.GetOrAdd(controllerName, 0);
            if (count >= parallelRequestsToController)
            {
                return false;
            }

            if (userStat.TryUpdate(controllerName, count + 1, count))
            {
                return true;
            }
        }
    }

    /// <summary>
    /// Writes the pre-built <c>429 Too Many Requests</c> response.
    /// </summary>
    private async Task WriteTooManyRequestsAsync(HttpContext context)
    {
        context.Response.Clear();
        context.Response.StatusCode = (int)System.Net.HttpStatusCode.TooManyRequests;
        // Retry-After is expressed in seconds (RFC 9110); "1000" asked clients to wait ~17 minutes.
        context.Response.Headers.RetryAfter = "1";
        context.Response.Headers.ContentType = "text/html; charset=utf-8";
        await context.Response.BodyWriter.WriteAsync(responseBody);
    }
}
|