diff --git a/.env.example b/.env.example index 2387db4d69..4a05a2416f 100644 --- a/.env.example +++ b/.env.example @@ -20,4 +20,9 @@ BUILD_RECORDS_MAX_CONCURRENT=100 BUILD_RECORDS_MIN_TIME= # Set to true to enable the /fastly-cache-test route for debugging Fastly headers -ENABLE_FASTLY_TESTING= \ No newline at end of file +ENABLE_FASTLY_TESTING= + +# Needed to auth for AI search +CSE_COPILOT_SECRET= +CSE_COPILOT_ENDPOINT=https://cse-copilot-staging.service.iad.github.net + diff --git a/package-lock.json b/package-lock.json index 6014631180..f293871e63 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11759,9 +11759,10 @@ "license": "MIT" }, "node_modules/punycode": { - "version": "2.1.1", + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", + "integrity": "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==", "dev": true, - "license": "MIT", "engines": { "node": ">=6" } @@ -14272,9 +14273,9 @@ } }, "node_modules/type-fest": { - "version": "4.23.0", - "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-4.23.0.tgz", - "integrity": "sha512-ZiBujro2ohr5+Z/hZWHESLz3g08BBdrdLMieYFULJO+tWc437sn8kQsWLJoZErY8alNhxre9K4p3GURAG11n+w==", + "version": "4.26.1", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-4.26.1.tgz", + "integrity": "sha512-yOGpmOAL7CkKe/91I5O3gPICmJNLJ1G4zFYVAsRHg7M64biSnPtRj0WNQt++bRkjYOqjWXrhnUw1utzmVErAdg==", "engines": { "node": ">=16" }, diff --git a/src/frame/middleware/api.ts b/src/frame/middleware/api.ts index 62d77ef619..33ff9c8bb0 100644 --- a/src/frame/middleware/api.ts +++ b/src/frame/middleware/api.ts @@ -3,6 +3,7 @@ import { createProxyMiddleware } from 'http-proxy-middleware' import events from '@/events/middleware.js' import anchorRedirect from '@/rest/api/anchor-redirect.js' +import aiSearch from '@/search/middleware/ai-search' import search from '@/search/middleware/search-routes.js' import pageInfo from '@/pageinfo/middleware' import pageList from '@/pagelist/middleware' @@ -23,6 +24,23 @@ router.use('/pagelist', pageList) // local laptop, they don't have an Elasticsearch. Neither a running local // server or the known credentials to a remote Elasticsearch. Whenever // that's the case, they can just HTTP proxy to the production server. +if (process.env.CSE_COPILOT_ENDPOINT || process.env.NODE_ENV === 'test') { + router.use('/ai-search', aiSearch) +} else { + console.log( + 'Proxying AI Search requests to docs.github.com. To use the cse-copilot endpoint, set the CSE_COPILOT_ENDPOINT environment variable.', + ) + router.use( + '/ai-search', + createProxyMiddleware({ + target: 'https://docs.github.com', + changeOrigin: true, + pathRewrite: function (path, req: ExtendedRequest) { + return req.originalUrl + }, + }), + ) +} if (process.env.ELASTICSEARCH_URL) { router.use('/search', search) } else { diff --git a/src/search/lib/ai-search-proxy.ts b/src/search/lib/ai-search-proxy.ts new file mode 100644 index 0000000000..9ddfdecb02 --- /dev/null +++ b/src/search/lib/ai-search-proxy.ts @@ -0,0 +1,78 @@ +import { Request, Response } from 'express' +import got from 'got' +import { getHmacWithEpoch } from '@/search/lib/helpers/get-cse-copilot-auth' +import { getCSECopilotSource } from '#src/search/lib/helpers/cse-copilot-docs-versions.js' + +export const aiSearchProxy = async (req: Request, res: Response) => { + const { query, version, language } = req.body + const errors = [] + + // Validate request body + if (!query) { + errors.push({ message: `Missing required key 'query' in request body` }) + } else if (typeof query !== 'string') { + errors.push({ message: `Invalid 'query' in request body. Must be a string` }) + } + if (!version) { + errors.push({ message: `Missing required key 'version' in request body` }) + } + if (!language) { + errors.push({ message: `Missing required key 'language' in request body` }) + } + + let docsSource = '' + try { + docsSource = getCSECopilotSource(version, language) + } catch (error: any) { + errors.push({ message: error?.message || 'Invalid version or language' }) + } + + if (errors.length) { + res.status(400).json({ errors }) + return + } + + const body = { + chat_context: 'defaults', + docs_source: docsSource, + query, + stream: true, + } + + try { + const stream = got.post(`${process.env.CSE_COPILOT_ENDPOINT}/answers`, { + json: body, + headers: { + Authorization: getHmacWithEpoch(), + 'Content-Type': 'application/json', + }, + isStream: true, + }) + + // Set response headers + res.setHeader('Content-Type', 'application/x-ndjson') + res.flushHeaders() + + // Pipe the got stream directly to the response + stream.pipe(res) + + // Handle stream errors + stream.on('error', (error) => { + console.error('Error streaming from cse-copilot:', error) + // Only send error response if headers haven't been sent + if (!res.headersSent) { + res.status(500).json({ errors: [{ message: 'Internal server error' }] }) + } else { + res.end() + } + }) + + // Ensure response ends when stream ends + stream.on('end', () => { + res.end() + }) + } catch (error) { + console.error('Error posting /answers to cse-copilot:', error) + res.status(500).json({ errors: [{ message: 'Internal server error' }] }) + } +} diff --git a/src/search/lib/helpers/cse-copilot-docs-versions.ts b/src/search/lib/helpers/cse-copilot-docs-versions.ts new file mode 100644 index 0000000000..9b96aa9ddc --- /dev/null +++ b/src/search/lib/helpers/cse-copilot-docs-versions.ts @@ -0,0 +1,44 @@ +// Versions used by cse-copilot +import { allVersions } from '@/versions/lib/all-versions' +const CSE_COPILOT_DOCS_VERSIONS = ['dotcom', 'ghec', 'ghes'] + +// Languages supported by cse-copilot +const DOCS_LANGUAGES = ['en'] +export function supportedCSECopilotLanguages() { + return DOCS_LANGUAGES +} + +export function getCSECopilotSource( + version: (typeof CSE_COPILOT_DOCS_VERSIONS)[number], + language: (typeof DOCS_LANGUAGES)[number], +) { + const cseCopilotDocsVersion = getMiscBaseNameFromVersion(version) + if (!CSE_COPILOT_DOCS_VERSIONS.includes(cseCopilotDocsVersion)) { + throw new Error( + `Invalid 'version' in request body: '${version}'. Must be one of: ${CSE_COPILOT_DOCS_VERSIONS.join(', ')}`, + ) + } + if (!DOCS_LANGUAGES.includes(language)) { + throw new Error( + `Invalid 'language' in request body '${language}'. Must be one of: ${DOCS_LANGUAGES.join(', ')}`, + ) + } + return `docs_${version}_${language}` +} + +function getMiscBaseNameFromVersion(Version: string): string { + const miscBaseName = + Object.values(allVersions).find( + (info) => + info.shortName === Version || + info.plan === Version || + info.miscVersionName === Version || + info.currentRelease === Version, + )?.miscBaseName || '' + + if (!miscBaseName) { + return '' + } + + return miscBaseName +} diff --git a/src/search/lib/helpers/get-cse-copilot-auth.ts b/src/search/lib/helpers/get-cse-copilot-auth.ts new file mode 100644 index 0000000000..a636852452 --- /dev/null +++ b/src/search/lib/helpers/get-cse-copilot-auth.ts @@ -0,0 +1,24 @@ +import crypto from 'crypto' + +// github/cse-copilot's API requires an HMAC-SHA256 signature with each request +export function getHmacWithEpoch() { + const epochTime = getEpochTime().toString() + // CSE_COPILOT_SECRET needs to be set for the api-ai-search tests to work + if (process.env.NODE_ENV === 'test') { + process.env.CSE_COPILOT_SECRET = 'mock-secret' + } + if (!process.env.CSE_COPILOT_SECRET) { + throw new Error('CSE_COPILOT_SECRET is not defined') + } + const hmac = generateHmacSha256(process.env.CSE_COPILOT_SECRET, epochTime) + return `${epochTime}.${hmac}` +} + +// In seconds +function getEpochTime(): number { + return Math.floor(Date.now() / 1000) +} + +function generateHmacSha256(key: string, data: string): string { + return crypto.createHmac('sha256', key).update(data).digest('hex') +} diff --git a/src/search/middleware/ai-search.ts b/src/search/middleware/ai-search.ts new file mode 100644 index 0000000000..f2cf89fbc7 --- /dev/null +++ b/src/search/middleware/ai-search.ts @@ -0,0 +1,20 @@ +import express, { Request, Response } from 'express' + +import catchMiddlewareError from '#src/observability/middleware/catch-middleware-error.js' +import { aiSearchProxy } from '../lib/ai-search-proxy' + +const router = express.Router() + +router.post( + '/v1', + catchMiddlewareError(async (req: Request, res: Response) => { + await aiSearchProxy(req, res) + }), +) + +// Redirect to most recent version +router.post('/', (req, res) => { + res.redirect(307, req.originalUrl.replace('/ai-search', '/ai-search/v1')) +}) + +export default router diff --git a/src/search/tests/api-ai-search.ts b/src/search/tests/api-ai-search.ts new file mode 100644 index 0000000000..9b66f2c6db --- /dev/null +++ b/src/search/tests/api-ai-search.ts @@ -0,0 +1,148 @@ +import { expect, test, describe, beforeAll, afterAll } from 'vitest' + +import { post } from 'src/tests/helpers/e2etest.js' +import { startMockServer, stopMockServer } from '@/tests/mocks/start-mock-server' + +describe('AI Search Routes', () => { + beforeAll(() => { + startMockServer() + }) + afterAll(() => stopMockServer()) + + test('/api/ai-search/v1 should handle a successful response', async () => { + let apiBody = { query: 'How do I create a Repository?', language: 'en', version: 'dotcom' } + + const response = await fetch('http://localhost:4000/api/ai-search/v1', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(apiBody), + }) + + expect(response.ok).toBe(true) + expect(response.headers.get('content-type')).toBe('application/x-ndjson') + expect(response.headers.get('transfer-encoding')).toBe('chunked') + + if (!response.body) { + throw new Error('ReadableStream not supported in this environment.') + } + + const decoder = new TextDecoder('utf-8') + const reader = response.body.getReader() + let done = false + const chunks = [] + + while (!done) { + const { value, done: readerDone } = await reader.read() + done = readerDone + + if (value) { + // Decode the Uint8Array chunk into a string + const chunkStr = decoder.decode(value, { stream: true }) + chunks.push(chunkStr) + } + } + + // Combine all chunks into a single string + const fullResponse = chunks.join('') + // Split the response into individual chunk lines + const chunkLines = fullResponse.split('\n').filter((line) => line.trim() !== '') + + // Assertions: + + // 1. First chunk should be the SOURCES chunk + expect(chunkLines.length).toBeGreaterThan(0) + const firstChunkMatch = chunkLines[0].match(/^Chunk: (.+)$/) + expect(firstChunkMatch).not.toBeNull() + + const sourcesChunk = JSON.parse(firstChunkMatch?.[1] || '') + expect(sourcesChunk).toHaveProperty('chunkType', 'SOURCES') + expect(sourcesChunk).toHaveProperty('sources') + expect(Array.isArray(sourcesChunk.sources)).toBe(true) + expect(sourcesChunk.sources.length).toBe(3) + + // 2. Subsequent chunks should be MESSAGE_CHUNKs + for (let i = 1; i < chunkLines.length; i++) { + const line = chunkLines[i] + const messageChunk = JSON.parse(line) + expect(messageChunk).toHaveProperty('chunkType', 'MESSAGE_CHUNK') + expect(messageChunk).toHaveProperty('text') + expect(typeof messageChunk.text).toBe('string') + } + + // 3. Verify the complete message is expected + const expectedMessage = + 'Creating a repository on GitHub is something you should already know how to do :shrug:' + const receivedMessage = chunkLines + .slice(1) + .map((line) => JSON.parse(line).text) + .join('') + expect(receivedMessage).toBe(expectedMessage) + }) + + test('should handle validation errors: query missing', async () => { + let body = { language: 'en', version: 'dotcom' } + const response = await post('/api/ai-search/v1', { + body: JSON.stringify(body), + headers: { 'Content-Type': 'application/json' }, + }) + + const responseBody = JSON.parse(response.body) + + expect(response.ok).toBe(false) + expect(responseBody['errors']).toEqual([ + { message: `Missing required key 'query' in request body` }, + ]) + }) + + test('should handle validation errors: language missing', async () => { + let body = { query: 'example query', version: 'dotcom' } + const response = await post('/api/ai-search/v1', { + body: JSON.stringify(body), + headers: { 'Content-Type': 'application/json' }, + }) + + const responseBody = JSON.parse(response.body) + + expect(response.ok).toBe(false) + expect(responseBody['errors']).toEqual([ + { message: `Missing required key 'language' in request body` }, + { message: `Invalid 'language' in request body 'undefined'. Must be one of: en` }, + ]) + }) + + test('should handle validation errors: version missing', async () => { + let body = { query: 'example query', language: 'en' } + const response = await post('/api/ai-search/v1', { + body: JSON.stringify(body), + headers: { 'Content-Type': 'application/json' }, + }) + + const responseBody = JSON.parse(response.body) + + expect(response.ok).toBe(false) + expect(responseBody['errors']).toEqual([ + { message: `Missing required key 'version' in request body` }, + { + message: `Invalid 'version' in request body: 'undefined'. Must be one of: dotcom, ghec, ghes`, + }, + ]) + }) + + test('should handle multiple validation errors: query missing, invalid language and version', async () => { + let body = { language: 'fr', version: 'fpt' } + const response = await post('/api/ai-search/v1', { + body: JSON.stringify(body), + headers: { 'Content-Type': 'application/json' }, + }) + + const responseBody = JSON.parse(response.body) + + expect(response.ok).toBe(false) + expect(responseBody['errors']).toEqual([ + { message: `Missing required key 'query' in request body` }, + { + message: `Invalid 'language' in request body 'fr'. Must be one of: en`, + }, + ]) + }) +}) diff --git a/src/tests/mocks/cse-copilot-mock.ts b/src/tests/mocks/cse-copilot-mock.ts new file mode 100644 index 0000000000..e2b75a17c3 --- /dev/null +++ b/src/tests/mocks/cse-copilot-mock.ts @@ -0,0 +1,71 @@ +import { Request, Response } from 'express' + +// Prefix used for mocking. This can be any value +export const CSE_COPILOT_PREFIX = 'cse-copilot' + +export function cseCopilotPostAnswersMock(req: Request, res: Response) { + // Set headers for chunked transfer and encoding + res.setHeader('Content-Type', 'application/json; charset=utf-8') + res.setHeader('Transfer-Encoding', 'chunked') + + // Define the SOURCES chunk + const sourcesChunk = { + chunkType: 'SOURCES', + sources: [ + { + title: 'Creating a new repository', + url: 'https://docs.github.com/en/repositories/creating-and-managing-repositories/creating-a-new-repository', + index: '/en/repositories/creating-and-managing-repositories/creating-a-new-repository', + }, + { + title: 'Creating and managing repositories', + url: 'https://docs.github.com/en/repositories/creating-and-managing-repositories', + index: '/en/repositories/creating-and-managing-repositories', + }, + { + title: 'GitHub Terms of Service', + url: 'https://docs.github.com/en/site-policy/github-terms/github-terms-of-service', + index: '/en/site-policy/github-terms/github-terms-of-service', + }, + ], + } + + // Function to send a chunk with proper encoding + const sendEncodedChunk = (data: any, isLast = false) => { + const prefix = isLast ? '' : '\n' // Optionally, add delimiters if needed + const buffer = Buffer.from(prefix + data, 'utf-8') + res.write(buffer) + } + + // Send the SOURCES chunk + sendEncodedChunk(`Chunk: ${JSON.stringify(sourcesChunk)}\n\n`) + + // Define the message to be sent in chunks + const message = + 'Creating a repository on GitHub is something you should already know how to do :shrug:' + + // Split the message into words (or adjust the splitting logic as needed) + const words = message.split(' ') + + let index = 0 + + const sendChunk = () => { + if (index < words.length) { + const word = words[index] + const isLastWord = index === words.length - 1 + const chunk = { + chunkType: 'MESSAGE_CHUNK', + text: word + (isLastWord ? '' : ' '), // Add space if not the last word + } + sendEncodedChunk(`${JSON.stringify(chunk)}\n`) + index++ + sendChunk() // Adjust the delay as needed + } else { + // End the response after all chunks are sent + res.end() + } + } + + // Start sending MESSAGE_CHUNKs + sendChunk() +} diff --git a/src/tests/mocks/start-mock-server.ts b/src/tests/mocks/start-mock-server.ts new file mode 100644 index 0000000000..93c01f5ba2 --- /dev/null +++ b/src/tests/mocks/start-mock-server.ts @@ -0,0 +1,73 @@ +/* When testing API routes via an integration test, e.g. + +const res = await post('/api/', { + body: JSON.stringify(api_body), + headers: { 'Content-Type': 'application/json' }, +}) + +expect(res.status).toBe(200) + +The `api/` may call an external URL. + +We are unable to use `nock` in this circumstance since we run the server in a separate instance. + +Instead, we can use the `startMockServer` helper to start a mock server that will intercept the request and return a canned response. + +In order for this to work you MUST use a process.env variable for the URL you are calling, + +e.g. `process.env.CSE_COPILOT_ENDPOINT` + +You should override the variable in the overrideEnvForTesting function in this file. +*/ + +import express from 'express' +import { CSE_COPILOT_PREFIX, cseCopilotPostAnswersMock } from './cse-copilot-mock' + +// Define the default port for the mock server +const MOCK_SERVER_PORT = 3012 + +// Construct the server URL using the defined port +const serverUrl = `http://localhost:${MOCK_SERVER_PORT}` + +// Variable to hold the server instance +let server: any = null + +// Override environment variables for testing purposes +export function overrideEnvForTesting() { + process.env.CSE_COPILOT_ENDPOINT = `${serverUrl}/${CSE_COPILOT_PREFIX}` +} + +// Function to start the mock server +export function startMockServer(port = MOCK_SERVER_PORT) { + const app = express() + app.use(express.json()) + + // Define your mock routes here + app.post(`/${CSE_COPILOT_PREFIX}/answers`, cseCopilotPostAnswersMock) + + // Start the server and store the server instance + server = app.listen(port, () => { + console.log(`Mock server is running on port ${port}`) + }) +} + +// Function to stop the mock server +export function stopMockServer(): Promise { + return new Promise((resolve, reject) => { + if (server) { + server.close((err: any) => { + if (err) { + console.error('Error stopping the mock server:', err) + reject(err) + } else { + console.log('Mock server has been stopped.') + server = null + resolve() + } + }) + } else { + console.warn('Mock server is not running.') + resolve() + } + }) +} diff --git a/src/tests/vitest.setup.ts b/src/tests/vitest.setup.ts index f0d68a978f..3e7ce07708 100644 --- a/src/tests/vitest.setup.ts +++ b/src/tests/vitest.setup.ts @@ -1,4 +1,5 @@ import { main } from 'src/frame/start-server' +import { overrideEnvForTesting } from './mocks/start-mock-server' let teardownHappened = false type PromiseType> = T extends Promise ? U : never @@ -7,6 +8,7 @@ type Server = PromiseType> let server: Server | undefined export async function setup() { + overrideEnvForTesting() server = await main() }