import type { Duration } from "./duration.ts";import { ms } from "./duration.ts";import type { Algorithm, MultiRegionContext } from "./types.ts";import { Ratelimit } from "./ratelimit.ts";import type { Redis } from "./types.ts";
/**
 * Construction options for `MultiRegionRatelimit`.
 */
export type MultiRegionRatelimitConfig = {
  /** Limiting algorithm, e.g. one returned by `MultiRegionRatelimit.fixedWindow`. */
  limiter: Algorithm<MultiRegionContext>;
  /** Redis clients; the built-in algorithms send every request to each of them. */
  redis: Redis[];
  /**
   * Optional prefix, forwarded unchanged to the base `Ratelimit`.
   * NOTE(review): presumably used to namespace keys — confirm in `Ratelimit`.
   */
  prefix?: string;
};
export class MultiRegionRatelimit extends Ratelimit<MultiRegionContext> {
  /**
   * Lua script used by `synchronize` to add request IDs to one region's set.
   *
   * KEYS[1]   - set key
   * ARGV[1]   - expiry in milliseconds, applied only when the key has none
   * ARGV[2..] - request IDs to add
   *
   * A bare SADD would (re)create a missing key without an expiry and leave it
   * in Redis forever; checking PTTL afterwards prevents that leak.
   */
  private static readonly syncScript = `
    local key = KEYS[1]
    local ttl = ARGV[1]
    for i = 2, #ARGV do
      redis.call("SADD", key, ARGV[i])
    end
    if redis.call("PTTL", key) < 0 then
      redis.call("PEXPIRE", key, ttl)
    end
    return 1
  `;

  /**
   * Creates a rate limiter whose context holds every configured Redis client.
   * How those clients are queried is decided by the chosen algorithm
   * (`fixedWindow` / `slidingWindow` query all of them and use the first answer).
   *
   * @param config - Redis clients, limiting algorithm and optional prefix.
   */
  constructor(config: MultiRegionRatelimitConfig) {
    super({
      prefix: config.prefix,
      limiter: config.limiter,
      ctx: { redis: config.redis },
    });
  }

  /**
   * Fixed-window algorithm: each identifier may make `tokens` requests per
   * bucket, where buckets are consecutive `window`-long intervals since epoch.
   *
   * Every request, including rejected ones, is added under a random ID to a
   * Redis set in every region. The decision uses whichever region answers first.
   * Afterwards, `pending` copies IDs missing in some regions into them.
   *
   * @param tokens - Number of requests allowed per window.
   * @param window - Window length, converted to milliseconds with `ms`.
   */
  static fixedWindow(
    tokens: number,
    window: Duration,
  ): Algorithm<MultiRegionContext> {
    const windowDuration = ms(window);
    const script = `
      local key     = KEYS[1]
      local id      = ARGV[1]
      local window  = ARGV[2]

      redis.call("SADD", key, id)
      local members = redis.call("SMEMBERS", key)
      if #members == 1 then
        -- The first time this key is set, the value will be 1.
        -- So we only need the expire command once
        redis.call("PEXPIRE", key, window)
      end

      return members
    `;

    return async function (ctx: MultiRegionContext, identifier: string) {
      const requestID = crypto.randomUUID();
      const bucket = Math.floor(Date.now() / windowDuration);
      const key = [identifier, bucket].join(":");

      const dbs: { redis: Redis; request: Promise<string[]> }[] = ctx.redis.map(
        (redis) => ({
          redis,
          request: redis.eval(
            script,
            [key],
            [requestID, windowDuration],
          ) as Promise<string[]>,
        }),
      );

      const firstResponse = await Promise.any(dbs.map((s) => s.request));

      // The script adds this request before reading the set, so the count
      // already includes it: the request fits while usedTokens <= tokens.
      const usedTokens = firstResponse.length;
      const remaining = tokens - usedTokens;

      return {
        success: remaining >= 0,
        limit: tokens,
        remaining: Math.max(0, remaining),
        reset: (bucket + 1) * windowDuration,
        pending: MultiRegionRatelimit.synchronize(
          dbs.map(async ({ redis, request }) => ({ redis, ids: await request })),
          key,
          tokens,
          windowDuration,
        ),
      };
    };
  }

  /**
   * Sliding-window algorithm: approximates a rolling window by weighting the
   * previous bucket's count by the fraction of it still inside the window
   * ending now, then adding the current bucket's count.
   *
   * A request is recorded (under a random ID) only if the weighted count is
   * below `tokens`. The decision uses whichever region answers first.
   * Afterwards, `pending` copies current-bucket IDs missing in some regions into them.
   *
   * @param tokens - Number of requests allowed per window.
   * @param window - Window length, converted to milliseconds with `ms`.
   */
  static slidingWindow(
    tokens: number,
    window: Duration,
  ): Algorithm<MultiRegionContext> {
    const windowSize = ms(window);
    const script = `
      local currentKey  = KEYS[1]           -- identifier including prefixes
      local previousKey = KEYS[2]           -- key of the previous bucket
      local tokens      = tonumber(ARGV[1]) -- tokens per window
      local now         = ARGV[2]           -- current timestamp in milliseconds
      local window      = ARGV[3]           -- interval in milliseconds
      local requestID   = ARGV[4]           -- uuid for this request

      local currentMembers = redis.call("SMEMBERS", currentKey)
      local requestsInCurrentWindow = #currentMembers
      local previousMembers = redis.call("SMEMBERS", previousKey)
      local requestsInPreviousWindow = #previousMembers

      local percentageInCurrent = ( now % window) / window
      if requestsInPreviousWindow * ( 1 - percentageInCurrent ) + requestsInCurrentWindow >= tokens then
        return {currentMembers, previousMembers}
      end

      redis.call("SADD", currentKey, requestID)
      table.insert(currentMembers, requestID)
      if requestsInCurrentWindow == 0 then
        -- The first time this key is set, the value will be 1.
        -- So we only need the expire command once
        redis.call("PEXPIRE", currentKey, window * 2 + 1000) -- Enough time to overlap with a new window + 1 second
      end
      return {currentMembers, previousMembers}
    `;

    return async function (ctx: MultiRegionContext, identifier: string) {
      const requestID = crypto.randomUUID();
      const now = Date.now();

      const currentWindow = Math.floor(now / windowSize);
      const currentKey = [identifier, currentWindow].join(":");
      // Buckets are numbered by window index, so the previous bucket is index - 1
      // (not index - windowSize, which would point at a bucket far in the past).
      const previousWindow = currentWindow - 1;
      const previousKey = [identifier, previousWindow].join(":");

      const dbs: { redis: Redis; request: Promise<[string[], string[]]> }[] =
        ctx.redis.map((redis) => ({
          redis,
          request: redis.eval(
            script,
            [currentKey, previousKey],
            [tokens, now, windowSize, requestID],
          ) as Promise<[string[], string[]]>,
        }));

      const percentageInCurrent = (now % windowSize) / windowSize;
      const [current, previous] = await Promise.any(dbs.map((s) => s.request));

      // The script appends requestID to the current members only when it
      // accepted (and recorded) the request, so that is the authoritative answer.
      const success = current.includes(requestID);
      const usedTokens =
        previous.length * (1 - percentageInCurrent) + current.length;

      return {
        success,
        limit: tokens,
        remaining: Math.max(0, tokens - usedTokens),
        reset: (currentWindow + 1) * windowSize,
        pending: MultiRegionRatelimit.synchronize(
          dbs.map(async ({ redis, request }) => ({
            redis,
            ids: (await request)[0],
          })),
          currentKey,
          tokens,
          // Same lifetime the script gives a freshly created current-window key.
          windowSize * 2 + 1000,
        ),
      };
    };
  }

  /**
   * Copies request IDs that some regions recorded for `key` into the regions
   * that lack them, so all regions converge on the same set.
   *
   * A region that already holds at least `tokens` IDs is left untouched. IDs
   * are written with `syncScript`, which also applies `ttl` if the key has no
   * expiry. Rejects if any region's original request rejected.
   *
   * @param regions - Per region: its client and the IDs it returned for `key`.
   * @param key - Redis set key to reconcile.
   * @param tokens - Limit per window; regions at or above it are skipped.
   * @param ttl - Expiry in milliseconds for keys that have none.
   */
  private static async synchronize(
    regions: Promise<{ redis: Redis; ids: string[] }>[],
    key: string,
    tokens: number,
    ttl: number,
  ): Promise<void> {
    const resolved = await Promise.all(regions);
    const allIDs = Array.from(new Set(resolved.flatMap(({ ids }) => ids)));

    await Promise.all(
      resolved.map(async ({ redis, ids }) => {
        if (ids.length >= tokens) {
          return;
        }
        const known = new Set(ids);
        const missing = allIDs.filter((id) => !known.has(id));
        if (missing.length === 0) {
          return;
        }
        await redis.eval(
          MultiRegionRatelimit.syncScript,
          [key],
          [ttl, ...missing],
        );
      }),
    );
  }
}