Port shielding middleware to TypeScript (#51146)

This commit is contained in:
Peter Bengtsson 2024-06-12 13:22:03 -04:00 коммит произвёл GitHub
Родитель 67085dedfd
Коммит b58e73c51c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
13 изменённых файлов: 90 добавлений и 44 удалений

Просмотреть файл

@ -63,7 +63,7 @@ import fastlyBehavior from './fastly-behavior.js'
import mockVaPortal from './mock-va-portal.js' import mockVaPortal from './mock-va-portal.js'
import dynamicAssets from '@/assets/middleware/dynamic-assets.js' import dynamicAssets from '@/assets/middleware/dynamic-assets.js'
import contextualizeSearch from '@/search/middleware/contextualize.js' import contextualizeSearch from '@/search/middleware/contextualize.js'
import shielding from '@/shielding/middleware/index.js' import shielding from '@/shielding/middleware'
import tracking from '@/tracking/middleware/index.js' import tracking from '@/tracking/middleware/index.js'
import { MAX_REQUEST_TIMEOUT } from '@/frame/lib/constants.js' import { MAX_REQUEST_TIMEOUT } from '@/frame/lib/constants.js'

Просмотреть файл

@ -1,3 +1,7 @@
import type { Response, NextFunction } from 'express'
import { ExtendedRequest } from '@/types'
const INVALID_HEADER_KEYS = [ const INVALID_HEADER_KEYS = [
// Next.js will pick this up and override the status code. // Next.js will pick this up and override the status code.
// We don't want that to happen because `x-invoke-status: 203` can // We don't want that to happen because `x-invoke-status: 203` can
@ -7,7 +11,11 @@ const INVALID_HEADER_KEYS = [
'x-invoke-status', 'x-invoke-status',
] ]
export default function handleInvalidNextPaths(req, res, next) { export default function handleInvalidNextPaths(
req: ExtendedRequest,
res: Response,
next: NextFunction,
) {
const header = INVALID_HEADER_KEYS.find((key) => req.headers[key]) const header = INVALID_HEADER_KEYS.find((key) => req.headers[key])
if (header) { if (header) {
// There's no point attempting to set a cache-control on this. // There's no point attempting to set a cache-control on this.

Просмотреть файл

@ -1,9 +1,16 @@
import statsd from '#src/observability/lib/statsd.js' import type { Response, NextFunction } from 'express'
import { defaultCacheControl } from '#src/frame/middleware/cache-control.js'
import statsd from '@/observability/lib/statsd.js'
import { defaultCacheControl } from '@/frame/middleware/cache-control.js'
import { ExtendedRequest } from '@/types'
const STATSD_KEY = 'middleware.handle_invalid_nextjs_paths' const STATSD_KEY = 'middleware.handle_invalid_nextjs_paths'
export default function handleInvalidNextPaths(req, res, next) { export default function handleInvalidNextPaths(
req: ExtendedRequest,
res: Response,
next: NextFunction,
) {
// For example, `/_next/bin/junk.css`. // For example, `/_next/bin/junk.css`.
// The reason for depending on checking NODE_ENV is that in development, // The reason for depending on checking NODE_ENV is that in development,
// the Nextjs server will send things like /_next/static/webpack/... // the Nextjs server will send things like /_next/static/webpack/...

Просмотреть файл

@ -1,4 +1,7 @@
import { defaultCacheControl } from '#src/frame/middleware/cache-control.js' import type { Response, NextFunction } from 'express'
import { defaultCacheControl } from '@/frame/middleware/cache-control.js'
import { ExtendedRequest } from '@/types'
// We'll check if the current request path is one of these, or ends with // We'll check if the current request path is one of these, or ends with
// one of these. // one of these.
@ -31,7 +34,7 @@ const JUNK_BASENAMES = new Set([
'.env', '.env',
]) ])
function isJunkPath(path) { function isJunkPath(path: string) {
if (JUNK_PATHS.has(path)) return true if (JUNK_PATHS.has(path)) return true
for (const junkPath of JUNK_ENDS) { for (const junkPath of JUNK_ENDS) {
@ -42,8 +45,8 @@ function isJunkPath(path) {
const basename = path.split('/').pop() const basename = path.split('/').pop()
// E.g. `/billing/.env.local` or `/billing/.env_sample` // E.g. `/billing/.env.local` or `/billing/.env_sample`
if (/^\.env(.|_)[\w.]+/.test(basename)) return true if (basename && /^\.env(.|_)[\w.]+/.test(basename)) return true
if (JUNK_BASENAMES.has(basename)) return true if (basename && JUNK_BASENAMES.has(basename)) return true
// Prevent various malicious injection attacks targeting Next.js // Prevent various malicious injection attacks targeting Next.js
if (path.match(/^\/_next[^/]/) || path === '/_next/data' || path === '/_next/data/') { if (path.match(/^\/_next[^/]/) || path === '/_next/data' || path === '/_next/data/') {
@ -60,7 +63,11 @@ function isJunkPath(path) {
return false return false
} }
export default function handleInvalidPaths(req, res, next) { export default function handleInvalidPaths(
req: ExtendedRequest,
res: Response,
next: NextFunction,
) {
if (isJunkPath(req.path)) { if (isJunkPath(req.path)) {
// We can all the CDN to cache these responses because they're // We can all the CDN to cache these responses because they're
// they're not going to suddenly work in the next deployment. // they're not going to suddenly work in the next deployment.

Просмотреть файл

@ -1,7 +1,10 @@
import statsd from '#src/observability/lib/statsd.js' import type { Response, NextFunction } from 'express'
import { allTools } from '#src/tools/lib/all-tools.js'
import { allPlatforms } from '#src/tools/lib/all-platforms.js' import { ExtendedRequest } from '@/types'
import { defaultCacheControl } from '#src/frame/middleware/cache-control.js' import statsd from '@/observability/lib/statsd.js'
import { allTools } from '@/tools/lib/all-tools.js'
import { allPlatforms } from '@/tools/lib/all-platforms.js'
import { defaultCacheControl } from '@/frame/middleware/cache-control.js'
const STATSD_KEY = 'middleware.handle_invalid_querystring_values' const STATSD_KEY = 'middleware.handle_invalid_querystring_values'
@ -29,14 +32,19 @@ const RECOGNIZED_VALUES = {
// //
const RECOGNIZED_VALUES_KEYS = new Set(Object.keys(RECOGNIZED_VALUES)) const RECOGNIZED_VALUES_KEYS = new Set(Object.keys(RECOGNIZED_VALUES))
export default function handleInvalidQuerystringValues(req, res, next) { export default function handleInvalidQuerystringValues(
req: ExtendedRequest,
res: Response,
next: NextFunction,
) {
const { method, query } = req const { method, query } = req
if (method === 'GET' || method === 'HEAD') { if (method === 'GET' || method === 'HEAD') {
for (const [key, value] of Object.entries(query)) { for (const [key, value] of Object.entries(query)) {
if (RECOGNIZED_VALUES_KEYS.has(key)) { if (RECOGNIZED_VALUES_KEYS.has(key)) {
const validValues = RECOGNIZED_VALUES[key] const validValues = RECOGNIZED_VALUES[key as keyof typeof RECOGNIZED_VALUES]
const values = Array.isArray(query[key]) ? query[key] : [query[key]] const value = query[key]
if (values.some((value) => !validValues.includes(value))) { const values = Array.isArray(value) ? value : [value]
if (values.some((value) => typeof value === 'string' && !validValues.includes(value))) {
if (process.env.NODE_ENV === 'development') { if (process.env.NODE_ENV === 'development') {
console.warn( console.warn(
'Warning! Invalid query string *value* detected. %O is not one of %O', 'Warning! Invalid query string *value* detected. %O is not one of %O',
@ -46,7 +54,7 @@ export default function handleInvalidQuerystringValues(req, res, next) {
} }
// Some value is not recognized. Redirect to the current URL // Some value is not recognized. Redirect to the current URL
// but with that query string key removed. // but with that query string key removed.
const sp = new URLSearchParams(query) const sp = new URLSearchParams(query as any)
sp.delete(key) sp.delete(key)
defaultCacheControl(res) defaultCacheControl(res)

Просмотреть файл

@ -1,5 +1,8 @@
import statsd from '#src/observability/lib/statsd.js' import type { Response, NextFunction } from 'express'
import { noCacheControl, defaultCacheControl } from '#src/frame/middleware/cache-control.js'
import statsd from '@/observability/lib/statsd.js'
import { noCacheControl, defaultCacheControl } from '@/frame/middleware/cache-control.js'
import { ExtendedRequest } from '@/types'
const STATSD_KEY = 'middleware.handle_invalid_querystrings' const STATSD_KEY = 'middleware.handle_invalid_querystrings'
@ -37,7 +40,11 @@ const RECOGNIZED_KEYS_BY_ANY = new Set([
'utm_campaign', 'utm_campaign',
]) ])
export default function handleInvalidQuerystrings(req, res, next) { export default function handleInvalidQuerystrings(
req: ExtendedRequest,
res: Response,
next: NextFunction,
) {
const { method, query, path } = req const { method, query, path } = req
if (method === 'GET' || method === 'HEAD') { if (method === 'GET' || method === 'HEAD') {
const originalKeys = Object.keys(query) const originalKeys = Object.keys(query)
@ -99,7 +106,7 @@ export default function handleInvalidQuerystrings(req, res, next) {
) )
} }
defaultCacheControl(res) defaultCacheControl(res)
const sp = new URLSearchParams(query) const sp = new URLSearchParams(query as any)
keys.forEach((key) => sp.delete(key)) keys.forEach((key) => sp.delete(key))
let newURL = req.path let newURL = req.path
if (sp.toString()) newURL += `?${sp}` if (sp.toString()) newURL += `?${sp}`

Просмотреть файл

@ -22,9 +22,16 @@
import fs from 'fs' import fs from 'fs'
import { errorCacheControl } from '#src/frame/middleware/cache-control.js' import type { Response, NextFunction } from 'express'
export default function handleOldNextDataPaths(req, res, next) { import { ExtendedRequest } from '@/types'
import { errorCacheControl } from '@/frame/middleware/cache-control.js'
export default function handleOldNextDataPaths(
req: ExtendedRequest,
res: Response,
next: NextFunction,
) {
if (req.path.startsWith('/_next/data/') && !req.path.startsWith('/_next/data/development/')) { if (req.path.startsWith('/_next/data/') && !req.path.startsWith('/_next/data/development/')) {
const requestBuildId = req.path.split('/')[3] const requestBuildId = req.path.split('/')[3]
if (requestBuildId !== getCurrentBuildID()) { if (requestBuildId !== getCurrentBuildID()) {
@ -35,7 +42,7 @@ export default function handleOldNextDataPaths(req, res, next) {
return next() return next()
} }
let _buildId let _buildId: string
function getCurrentBuildID() { function getCurrentBuildID() {
// Simple memoization // Simple memoization
if (!_buildId) { if (!_buildId) {

Просмотреть файл

@ -1,12 +1,12 @@
import express from 'express' import express from 'express'
import handleInvalidQuerystrings from './handle-invalid-query-strings.js' import handleInvalidQuerystrings from './handle-invalid-query-strings'
import handleInvalidPaths from './handle-invalid-paths.js' import handleInvalidPaths from './handle-invalid-paths'
import handleOldNextDataPaths from './handle-old-next-data-paths.js' import handleOldNextDataPaths from './handle-old-next-data-paths'
import handleInvalidQuerystringValues from './handle-invalid-query-string-values.js' import handleInvalidQuerystringValues from './handle-invalid-query-string-values'
import handleInvalidNextPaths from './handle-invalid-nextjs-paths.js' import handleInvalidNextPaths from './handle-invalid-nextjs-paths'
import handleInvalidHeaders from './handle-invalid-headers.js' import handleInvalidHeaders from './handle-invalid-headers'
import rateLimit from './rate-limit.js' import rateLimit from './rate-limit'
const router = express.Router() const router = express.Router()

Просмотреть файл

@ -1,7 +1,9 @@
import type { Request } from 'express'
import rateLimit from 'express-rate-limit' import rateLimit from 'express-rate-limit'
import statsd from '#src/observability/lib/statsd.js' import statsd from '@/observability/lib/statsd.js'
import { noCacheControl } from '#src/frame/middleware/cache-control.js' import { noCacheControl } from '@/frame/middleware/cache-control.js'
const EXPIRES_IN_AS_SECONDS = 60 const EXPIRES_IN_AS_SECONDS = 60
@ -33,7 +35,7 @@ export default rateLimit({
// the `x-forwarded-for` is always the origin IP with a port number // the `x-forwarded-for` is always the origin IP with a port number
// attached. E.g. `75.40.90.27:56675, 169.254.129.1` // attached. E.g. `75.40.90.27:56675, 169.254.129.1`
// This port number portion changes with every request, so we strip it. // This port number portion changes with every request, so we strip it.
ip = ip.replace(ipv4WithPort, '$1') ip = (ip || '').replace(ipv4WithPort, '$1')
return ip return ip
}, },
@ -112,7 +114,7 @@ const MISC_KEYS = [
* @param {Request} req * @param {Request} req
* @returns boolean * @returns boolean
*/ */
function isSuspiciousRequest(req) { function isSuspiciousRequest(req: Request) {
const keys = Object.keys(req.query) const keys = Object.keys(req.query)
// Since this function can only speculate by query strings (at the // Since this function can only speculate by query strings (at the

Просмотреть файл

@ -1,6 +1,6 @@
import { describe, expect, test } from 'vitest' import { describe, expect, test } from 'vitest'
import { get } from '#src/tests/helpers/e2etest.js' import { get } from '@/tests/helpers/e2etest.js'
describe('invalid headers', () => { describe('invalid headers', () => {
test('400 if containing x-invoke-status (instead of redirecting)', async () => { test('400 if containing x-invoke-status (instead of redirecting)', async () => {

Просмотреть файл

@ -1,6 +1,6 @@
import { describe, expect, test } from 'vitest' import { describe, expect, test } from 'vitest'
import { get } from '#src/tests/helpers/e2etest.js' import { get } from '@/tests/helpers/e2etest.js'
describe('invalid query string values', () => { describe('invalid query string values', () => {
test.each(['platform', 'tool'])('%a key', async (key) => { test.each(['platform', 'tool'])('%a key', async (key) => {

Просмотреть файл

@ -1,11 +1,11 @@
import { describe, expect, test } from 'vitest' import { describe, expect, test } from 'vitest'
import { get } from '#src/tests/helpers/e2etest.js' import { get } from '@/tests/helpers/e2etest.js'
import { import {
MAX_UNFAMILIAR_KEYS_BAD_REQUEST, MAX_UNFAMILIAR_KEYS_BAD_REQUEST,
MAX_UNFAMILIAR_KEYS_REDIRECT, MAX_UNFAMILIAR_KEYS_REDIRECT,
} from '#src/shielding/middleware/handle-invalid-query-strings.js' } from '@/shielding/middleware/handle-invalid-query-strings.js'
const alpha = Array.from(Array(26)).map((e, i) => i + 65) const alpha = Array.from(Array(26)).map((e, i) => i + 65)
const alphabet = alpha.map((x) => String.fromCharCode(x)) const alphabet = alpha.map((x) => String.fromCharCode(x))
@ -82,7 +82,7 @@ describe('invalid query strings', () => {
}) })
}) })
function randomCharacters(length) { function randomCharacters(length: number) {
let s = '' let s = ''
const pool = `abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789` const pool = `abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789`
while (s.length < length) { while (s.length < length) {

Просмотреть файл

@ -1,7 +1,7 @@
import { describe, expect, test } from 'vitest' import { describe, expect, test } from 'vitest'
import { SURROGATE_ENUMS } from '#src/frame/middleware/set-fastly-surrogate-key.js' import { SURROGATE_ENUMS } from '@/frame/middleware/set-fastly-surrogate-key.js'
import { get } from '#src/tests/helpers/e2etest.js' import { get } from '@/tests/helpers/e2etest.js'
describe('honeypotting', () => { describe('honeypotting', () => {
test('any GET with survey-vote and survey-token query strings is 400', async () => { test('any GET with survey-vote and survey-token query strings is 400', async () => {