Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

External Authentication Providers Support for Cody #6526

Merged
merged 10 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions agent/scripts/reverse-proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#!/usr/bin/env python3

from aiohttp import web, ClientSession
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a shebang and make it executable.

Add a brief comment or module doc comment explaining that this demonstrates using external authentication providers. Link to the docs for the X-Forwarded-User configuration would be great. And gentle disclaimer that it is for testing or demoing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done :)

from urllib.parse import urlparse
import argparse
import asyncio
import re

async def proxy_handler(request):
async with ClientSession(auto_decompress=False) as session:
print(f'Request to: {request.url}')

# Modify headers here
headers = dict(request.headers)

# Reset the Host header to use target server host instead of the proxy host
if 'Host' in headers:
headers['Host'] = urlparse(target_url).netloc.split(':')[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is ancient URL archaeology but with credentials the netloc is user:pass@host:port, but I guess we don't care about that...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, since this is solely for testing purposes I do not think we care too much


# 'chunked' encoding results in error 400 from Cloudflare, removing it still keeps response chunked anyway
if 'Transfer-Encoding' in headers:
del headers['Transfer-Encoding']

# Use value of 'Authorization: Bearer' to fill 'X-Forwarded-User' and remove 'Authorization' header
if 'Authorization' in headers:
match = re.match('Bearer (.*)', headers['Authorization'])
if match:
headers['X-Forwarded-User'] = match.group(1)
del headers['Authorization']

# Forward the request to target
async with session.request(
method=request.method,
url=f'{target_url}{request.path_qs}',
headers=headers,
data=await request.read()
) as response:
proxy_response = web.StreamResponse(
status=response.status,
headers=response.headers
)

await proxy_response.prepare(request)

# Stream the response back
async for chunk in response.content.iter_chunks():
await proxy_response.write(chunk[0])

await proxy_response.write_eof()
return proxy_response

app = web.Application()
app.router.add_route('*', '/{path_info:.*}', proxy_handler)

"""
Reverse Proxy Server for testing External Auth Providers in Cody

This script implements a simple reverse proxy server to facilitate testing of external authentication providers
with Cody. It's role is to simulate simulate HTTP authentication proxy setups. It handles incoming requests by:
- Forwarding them to a target Sourcegraph instance
- Converting Bearer tokens from Authorization headers into X-Forwarded-User headers
- Managing request/response streaming
- Handling header modifications required for Cloudflare compatibility

Target Sourcegraph instance needs to be configured to use HTTP authentication proxies
as described in https://sourcegraph.com/docs/admin/auth#http-authentication-proxies
"""
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='External auth provider test proxy server')
parser.add_argument('target_url', help='Target Sourcegraph instance URL to proxy to')
parser.add_argument('proxy_port', type=int, nargs='?', default=5555,
help='Port for the proxy server (default: %(default)s)')

args = parser.parse_args()

target_url = args.target_url.rstrip('/')
port = args.proxy_port

print(f'Starting proxy server on port {port} targeting {target_url}...')
web.run_app(app, port=port)
2 changes: 1 addition & 1 deletion agent/src/AgentWorkspaceConfiguration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ export class AgentWorkspaceConfiguration implements vscode.WorkspaceConfiguratio

function mergeWithBaseConfig(config: any) {
for (const [key, value] of Object.entries(config)) {
if (typeof value === 'object') {
if (typeof value === 'object' && !Array.isArray(value)) {
const existing = _.get(baseConfig, key) ?? {}
const merged = _.merge(existing, value)
_.set(baseConfig, key, merged)
Expand Down
9 changes: 5 additions & 4 deletions agent/src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1486,11 +1486,10 @@ export class Agent extends MessageHandler implements ExtensionClient {
config: ExtensionConfiguration,
params?: { forceAuthentication: boolean }
): Promise<AuthStatus> {
const isAuthChange = vscode_shim.isAuthenticationChange(config)
const isAuthChange = vscode_shim.isTokenOrEndpointChange(config)
vscode_shim.setExtensionConfiguration(config)

// If this is an authentication change we need to reauthenticate prior to firing events
// that update the clients
// If this is an token or endpoint change we need to save them prior to firing events that update the clients
try {
if ((isAuthChange || params?.forceAuthentication) && config.serverEndpoint) {
await authProvider.validateAndStoreCredentials(
Expand All @@ -1500,7 +1499,9 @@ export class Agent extends MessageHandler implements ExtensionClient {
},
auth: {
serverEndpoint: config.serverEndpoint,
accessToken: config.accessToken ?? null,
credentials: config.accessToken
? { token: config.accessToken, source: 'paste' }
: undefined,
},
clientState: {
anonymousUserID: config.anonymousUserID ?? null,
Expand Down
5 changes: 4 additions & 1 deletion agent/src/cli/command-auth/AuthenticatedAccount.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ export class AuthenticatedAccount {
): Promise<AuthenticatedAccount | Error> {
const graphqlClient = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessToken: options.accessToken, serverEndpoint: options.endpoint },
auth: {
credentials: { token: options.accessToken },
serverEndpoint: options.endpoint,
},
clientState: { anonymousUserID: null },
})
const userInfo = await graphqlClient.getCurrentUserInfo()
Expand Down
4 changes: 2 additions & 2 deletions agent/src/cli/command-auth/command-login.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async function loginAction(
: await captureAccessTokenViaBrowserRedirect(serverEndpoint, spinner)
const client = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessToken: token, serverEndpoint: serverEndpoint },
auth: { credentials: { token: options.accessToken }, serverEndpoint: serverEndpoint },
clientState: { anonymousUserID: null },
})
const userInfo = await client.getCurrentUserInfo()
Expand Down Expand Up @@ -256,7 +256,7 @@ async function promptUserAboutLoginMethod(spinner: Ora, options: LoginOptions):
try {
const client = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessToken: options.accessToken, serverEndpoint: options.endpoint },
auth: { credentials: { token: options.accessToken }, serverEndpoint: options.endpoint },
clientState: { anonymousUserID: null },
})
const userInfo = await client.getCurrentUserInfo()
Expand Down
2 changes: 1 addition & 1 deletion agent/src/cli/command-bench/command-bench.ts
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ export const benchCommand = new commander.Command('bench')
setStaticResolvedConfigurationWithAuthCredentials({
configuration: { customHeaders: {} },
auth: {
accessToken: options.srcAccessToken,
credentials: { token: options.srcAccessToken },
serverEndpoint: options.srcEndpoint,
},
})
Expand Down
5 changes: 4 additions & 1 deletion agent/src/cli/command-bench/llm-judge.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ export class LlmJudge {
localStorage.setStorage('noop')
setStaticResolvedConfigurationWithAuthCredentials({
configuration: { customHeaders: undefined },
auth: { accessToken: options.srcAccessToken, serverEndpoint: options.srcEndpoint },
auth: {
credentials: { token: options.srcAccessToken },
serverEndpoint: options.srcEndpoint,
},
})
setClientCapabilities({ configuration: {}, agentCapabilities: undefined })
this.client = new SourcegraphNodeCompletionsClient()
Expand Down
5 changes: 4 additions & 1 deletion agent/src/local-e2e/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ export class LocalSGInstance {
// for checking the LLM configuration section.
this.gqlclient = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { customHeaders: headers, telemetryLevel: 'agent' },
auth: { accessToken: this.params.accessToken, serverEndpoint: this.params.serverEndpoint },
auth: {
credentials: { token: this.params.accessToken },
serverEndpoint: this.params.serverEndpoint,
},
clientState: { anonymousUserID: null },
})
}
Expand Down
3 changes: 2 additions & 1 deletion agent/src/vscode-shim.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ export let extensionConfiguration: ExtensionConfiguration | undefined
export function setExtensionConfiguration(newConfig: ExtensionConfiguration): void {
extensionConfiguration = newConfig
}
export function isAuthenticationChange(newConfig: ExtensionConfiguration): boolean {

export function isTokenOrEndpointChange(newConfig: ExtensionConfiguration): boolean {
if (!extensionConfiguration) {
return true
}
Expand Down
29 changes: 27 additions & 2 deletions lib/shared/src/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,18 @@ export type TokenSource = 'redirect' | 'paste'
*/
export interface AuthCredentials {
serverEndpoint: string
accessToken: string | null
tokenSource?: TokenSource | undefined
credentials: HeaderCredential | TokenCredential | undefined
}

export interface HeaderCredential {
// We use function instead of property to prevent accidential top level serialization - we never want to store this data
getHeaders(): Record<string, string>
expiration: number | undefined
}

export interface TokenCredential {
token: string
source?: TokenSource
}

export interface AutoEditsTokenLimit {
Expand Down Expand Up @@ -71,6 +81,19 @@ export interface AgenticContextConfiguration {
}
}

export interface ExternalAuthCommand {
commandLine: readonly string[]
environment?: Record<string, string>
shell?: string
timeout?: number
windowsHide?: boolean
}

export interface ExternalAuthProvider {
endpoint: string
executable: ExternalAuthCommand
}

interface RawClientConfiguration {
net: NetConfiguration
codebase?: string
Expand Down Expand Up @@ -165,6 +188,8 @@ interface RawClientConfiguration {
*/
overrideServerEndpoint?: string | undefined
overrideAuthToken?: string | undefined

authExternalProviders: ExternalAuthProvider[]
}

/**
Expand Down
101 changes: 101 additions & 0 deletions lib/shared/src/configuration/auth-resolver.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import { describe, expect, test } from 'vitest'
import { type HeaderCredential, type TokenSource, isWindows } from '..'
import { resolveAuth } from './auth-resolver'
import type { ClientSecrets } from './resolver'

class TempClientSecrets implements ClientSecrets {
constructor(readonly store: Map<string, [string, TokenSource]>) {}

getToken(endpoint: string): Promise<string | undefined> {
return Promise.resolve(this.store.get(endpoint)?.[0])
}
getTokenSource(endpoint: string): Promise<TokenSource | undefined> {
return Promise.resolve(this.store.get(endpoint)?.[1])
}
}

describe('auth-resolver', () => {
test('resolve with serverEndpoint and credentials overrides', async () => {
const auth = await resolveAuth(
'sourcegraph.com',
{
authExternalProviders: [],
overrideServerEndpoint: 'my-endpoint.com',
overrideAuthToken: 'my-token',
},
new TempClientSecrets(new Map([['sourcegraph.com/', ['sgp_212323123', 'paste']]]))
)

expect(auth.serverEndpoint).toBe('my-endpoint.com/')
expect(auth.credentials).toEqual({ token: 'my-token' })
})

test('resolve with serverEndpoint override', async () => {
const auth = await resolveAuth(
'sourcegraph.com',
{
authExternalProviders: [],
overrideServerEndpoint: 'my-endpoint.com',
overrideAuthToken: undefined,
},
new TempClientSecrets(new Map([['my-endpoint.com/', ['sgp_212323123', 'paste']]]))
)

expect(auth.serverEndpoint).toBe('my-endpoint.com/')
expect(auth.credentials).toEqual({ token: 'sgp_212323123', source: 'paste' })
})

test('resolve with token override', async () => {
const auth = await resolveAuth(
'sourcegraph.com',
{
authExternalProviders: [],
overrideServerEndpoint: undefined,
overrideAuthToken: 'my-token',
},
new TempClientSecrets(new Map([['sourcegraph.com/', ['sgp_777777777', 'paste']]]))
)

expect(auth.serverEndpoint).toBe('sourcegraph.com/')
expect(auth.credentials).toEqual({ token: 'my-token' })
})

test('resolve custom auth provider', async () => {
const credentialsJson = JSON.stringify({
headers: { Authorization: 'token X' },
expiration: 1337,
})

const auth = await resolveAuth(
'sourcegraph.com',
{
authExternalProviders: [
{
endpoint: 'https://my-server.com',
executable: {
commandLine: [
isWindows() ? `echo ${credentialsJson}` : `echo '${credentialsJson}'`,
],
shell: isWindows() ? process.env.ComSpec : '/bin/bash',
timeout: 5000,
windowsHide: true,
},
},
],
overrideServerEndpoint: 'https://my-server.com',
overrideAuthToken: undefined,
},
new TempClientSecrets(new Map())
)

expect(auth.serverEndpoint).toBe('https://my-server.com/')

const headerCredential = auth.credentials as HeaderCredential
expect(headerCredential.expiration).toBe(1337)
expect(headerCredential.getHeaders()).toStrictEqual({
Authorization: 'token X',
})

expect(JSON.stringify(headerCredential)).not.toContain('token X')
})
})
Loading
Loading