Skip to content

Commit

Permalink
Refactor AuthCredentials type
Browse files Browse the repository at this point in the history
  • Loading branch information
pkukielka committed Jan 10, 2025
1 parent 87d2228 commit a0ad157
Show file tree
Hide file tree
Showing 24 changed files with 119 additions and 107 deletions.
4 changes: 3 additions & 1 deletion agent/src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1499,7 +1499,9 @@ export class Agent extends MessageHandler implements ExtensionClient {
},
auth: {
serverEndpoint: config.serverEndpoint,
accessTokenOrHeaders: config.accessToken ?? null,
credentials: config.accessToken
? { token: config.accessToken, source: 'paste' }
: undefined,
},
clientState: {
anonymousUserID: config.anonymousUserID ?? null,
Expand Down
5 changes: 4 additions & 1 deletion agent/src/cli/command-auth/AuthenticatedAccount.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ export class AuthenticatedAccount {
): Promise<AuthenticatedAccount | Error> {
const graphqlClient = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessTokenOrHeaders: options.accessToken, serverEndpoint: options.endpoint },
auth: {
credentials: { token: options.accessToken },
serverEndpoint: options.endpoint,
},
clientState: { anonymousUserID: null },
})
const userInfo = await graphqlClient.getCurrentUserInfo()
Expand Down
4 changes: 2 additions & 2 deletions agent/src/cli/command-auth/command-login.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async function loginAction(
: await captureAccessTokenViaBrowserRedirect(serverEndpoint, spinner)
const client = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessTokenOrHeaders: token, serverEndpoint: serverEndpoint },
auth: { credentials: { token: options.accessToken }, serverEndpoint: serverEndpoint },
clientState: { anonymousUserID: null },
})
const userInfo = await client.getCurrentUserInfo()
Expand Down Expand Up @@ -256,7 +256,7 @@ async function promptUserAboutLoginMethod(spinner: Ora, options: LoginOptions):
try {
const client = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessTokenOrHeaders: options.accessToken, serverEndpoint: options.endpoint },
auth: { credentials: { token: options.accessToken }, serverEndpoint: options.endpoint },
clientState: { anonymousUserID: null },
})
const userInfo = await client.getCurrentUserInfo()
Expand Down
2 changes: 1 addition & 1 deletion agent/src/cli/command-bench/command-bench.ts
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ export const benchCommand = new commander.Command('bench')
setStaticResolvedConfigurationWithAuthCredentials({
configuration: { customHeaders: {} },
auth: {
accessTokenOrHeaders: options.srcAccessToken,
credentials: { token: options.srcAccessToken },
serverEndpoint: options.srcEndpoint,
},
})
Expand Down
5 changes: 4 additions & 1 deletion agent/src/cli/command-bench/llm-judge.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ export class LlmJudge {
localStorage.setStorage('noop')
setStaticResolvedConfigurationWithAuthCredentials({
configuration: { customHeaders: undefined },
auth: { accessTokenOrHeaders: options.srcAccessToken, serverEndpoint: options.srcEndpoint },
auth: {
credentials: { token: options.srcAccessToken },
serverEndpoint: options.srcEndpoint,
},
})
setClientCapabilities({ configuration: {}, agentCapabilities: undefined })
this.client = new SourcegraphNodeCompletionsClient()
Expand Down
2 changes: 1 addition & 1 deletion agent/src/local-e2e/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ export class LocalSGInstance {
this.gqlclient = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { customHeaders: headers, telemetryLevel: 'agent' },
auth: {
accessTokenOrHeaders: this.params.accessToken,
credentials: { token: this.params.accessToken },
serverEndpoint: this.params.serverEndpoint,
},
clientState: { anonymousUserID: null },
Expand Down
24 changes: 20 additions & 4 deletions lib/shared/src/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,29 @@ import type { ReadonlyDeep } from './utils'
* A redirect flow is initiated by the user clicking a link in the browser, while a paste flow is initiated by the user
* manually entering the access from into the VsCode App.
*/
export type TokenSource = 'redirect' | 'paste' | 'custom-auth-provider'
export type TokenSource = 'redirect' | 'paste'

export type AuthHeaders = Record<string, string>
const nonSerializableSymbol = Symbol('nonSerializable')
export type NonSerializableRecord<K extends keyof any, T> = Record<K, T> & {
[nonSerializableSymbol]: never
}

/**
* The user's authentication credentials, which are stored separately from the rest of the
* configuration.
*/
export interface AuthCredentials {
serverEndpoint: string
tokenSource?: TokenSource | undefined
accessTokenOrHeaders: string | AuthHeaders | null
credentials: HeaderCredential | TokenCredential | undefined
}

export interface HeaderCredential {
headers: NonSerializableRecord<string, string>
}

export interface TokenCredential {
token: string
source?: TokenSource
}

export interface AutoEditsTokenLimit {
Expand Down Expand Up @@ -86,6 +97,11 @@ export interface ExternalAuthProvider {
executable: ExternalAuthCommand
}

export interface ExternalAuthProviderResult {
headers: NonSerializableRecord<string, string>
expiration: number
}

interface RawClientConfiguration {
net: NetConfiguration
codebase?: string
Expand Down
51 changes: 24 additions & 27 deletions lib/shared/src/configuration/resolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import type {
AuthCredentials,
ClientConfiguration,
ExternalAuthCommand,
NonSerializableRecord,
TokenSource,
} from '../configuration'
import { logError } from '../logger'
Expand Down Expand Up @@ -103,7 +104,7 @@ async function executeCommand(cmd: ExternalAuthCommand): Promise<string> {
async function getExternalProviderAuthHeaders(
serverEndpoint: string,
clientConfiguration: ClientConfiguration
): Promise<Record<string, string> | undefined> {
): Promise<NonSerializableRecord<string, string> | undefined> {
// Check for external auth provider for this endpoint
const externalProvider = clientConfiguration.authExternalProviders.find(
provider => normalizeServerEndpointURL(provider.endpoint) === serverEndpoint
Expand All @@ -122,39 +123,35 @@ export async function resolveAuth(
clientConfiguration: ClientConfiguration,
clientSecrets: ClientSecrets
): Promise<AuthCredentials> {
let accessTokenOrHeaders = null
let tokenSource: TokenSource | undefined = undefined

const serverEndpoint = normalizeServerEndpointURL(
clientConfiguration.overrideServerEndpoint || endpoint
)

// We must not throw here, because that would result in the `resolvedConfig` observable
// terminating and all callers receiving no further config updates.
const loadTokenFn = () =>
clientSecrets.getToken(serverEndpoint).catch(error => {
throw new Error(
`Failed to get access token for endpoint ${serverEndpoint}: ${error.message || error}`
)
})

if (clientConfiguration.overrideAuthToken) {
accessTokenOrHeaders = clientConfiguration.overrideAuthToken
} else
try {
const authHeaders = await getExternalProviderAuthHeaders(serverEndpoint, clientConfiguration)
if (authHeaders) {
accessTokenOrHeaders = authHeaders
tokenSource = 'custom-auth-provider'
} else {
accessTokenOrHeaders = (await loadTokenFn()) || null
tokenSource = await clientSecrets.getTokenSource(serverEndpoint).catch(_ => undefined)
}
} catch (error) {
return { credentials: { token: clientConfiguration.overrideAuthToken }, serverEndpoint }
}

const authHeaders = await getExternalProviderAuthHeaders(serverEndpoint, clientConfiguration).catch(
error => {
throw new Error(`Failed to execute external auth command: ${error}`)
}
)
if (authHeaders) {
return { credentials: { headers: authHeaders }, serverEndpoint }
}

const token = await clientSecrets.getToken(serverEndpoint).catch(error => {
throw new Error(
`Failed to get access token for endpoint ${serverEndpoint}: ${error.message || error}`
)
})

return { accessTokenOrHeaders, serverEndpoint, tokenSource }
return {
credentials: token
? { token, source: await clientSecrets.getTokenSource(serverEndpoint) }
: undefined,
serverEndpoint,
}
}

async function resolveConfiguration({
Expand All @@ -180,7 +177,7 @@ async function resolveConfiguration({
// We don't want to throw here, because that would cause the observable to terminate and
// all callers receiving no further config updates.
logError('resolveConfiguration', `Error resolving configuration: ${error}`)
const auth = { accessTokenOrHeaders: null, serverEndpoint, tokenSource: undefined }
const auth = { credentials: undefined, serverEndpoint }
return { configuration: clientConfiguration, clientState, auth, isReinstall }
}
}
Expand Down
2 changes: 1 addition & 1 deletion lib/shared/src/experimentation/FeatureFlagProvider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ describe('FeatureFlagProvider', () => {
beforeAll(() => {
vi.useFakeTimers()
mockResolvedConfig({
auth: { accessTokenOrHeaders: null, serverEndpoint: 'https://example.com' },
auth: { credentials: undefined, serverEndpoint: 'https://example.com' },
})
mockAuthStatus(AUTH_STATUS_FIXTURE_AUTHED)
})
Expand Down
13 changes: 7 additions & 6 deletions lib/shared/src/sourcegraph-api/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ export function toPartialUtf8String(buf: Buffer): { str: string; buf: Buffer } {

export function addAuthHeaders(auth: AuthCredentials, headers: Headers, url: URL): void {
// We want to be sure we sent authorization headers only to the valid endpoint
if (auth.accessTokenOrHeaders && url.host === new URL(auth.serverEndpoint).host) {
if (typeof auth.accessTokenOrHeaders === 'string') {
headers.set('Authorization', `token ${auth.accessTokenOrHeaders}`)
} else {
// Add headers as-is when accessTokenOrHeaders is a record of headers
for (const [key, value] of Object.entries(auth.accessTokenOrHeaders)) {
if (auth.credentials && url.host === new URL(auth.serverEndpoint).host) {
if ('token' in auth.credentials) {
headers.set('Authorization', `token ${auth.credentials.token}`)
} else if ('headers' in auth.credentials) {
for (const [key, value] of Object.entries(auth.credentials.headers)) {
headers.set(key, value)
}
} else {
console.error('Cannot add headers: neither token nor headers found')
}
}
}
22 changes: 9 additions & 13 deletions vscode/src/auth/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,17 @@ export async function showSignInMenu(
const { configuration } = await currentResolvedConfig()
const auth = await resolveAuth(selectedEndpoint, configuration, secretStorage)

let authStatus = auth.accessTokenOrHeaders
let authStatus = auth.credentials
? await authProvider.validateAndStoreCredentials(auth, 'store-if-valid')
: undefined

if (!authStatus?.authenticated) {
const newToken = await showAccessTokenInputBox(selectedEndpoint)
if (!newToken) {
const token = await showAccessTokenInputBox(selectedEndpoint)
if (!token) {
return
}
authStatus = await authProvider.validateAndStoreCredentials(
{
serverEndpoint: selectedEndpoint,
accessTokenOrHeaders: newToken,
tokenSource: 'paste',
},
{ serverEndpoint: selectedEndpoint, credentials: { token, source: 'paste' } },
'store-if-valid'
)
}
Expand Down Expand Up @@ -229,12 +225,12 @@ const LoginMenuOptionItems = [
]

async function signinMenuForInstanceUrl(instanceUrl: string): Promise<void> {
const accessToken = await showAccessTokenInputBox(instanceUrl)
if (!accessToken) {
const token = await showAccessTokenInputBox(instanceUrl)
if (!token) {
return
}
const authStatus = await authProvider.validateAndStoreCredentials(
{ serverEndpoint: instanceUrl, accessTokenOrHeaders: accessToken, tokenSource: 'paste' },
{ serverEndpoint: instanceUrl, credentials: { token, source: 'paste' } },
'store-if-valid'
)
telemetryRecorder.recordEvent('cody.auth.signin.token', 'clicked', {
Expand Down Expand Up @@ -313,7 +309,7 @@ export async function tokenCallbackHandler(uri: vscode.Uri): Promise<void> {
}

const authStatus = await authProvider.validateAndStoreCredentials(
{ serverEndpoint: endpoint, accessTokenOrHeaders: token, tokenSource: 'redirect' },
{ serverEndpoint: endpoint, credentials: { token, source: 'redirect' } },
'store-if-valid'
)
telemetryRecorder.recordEvent('cody.auth.fromCallback.web', 'succeeded', {
Expand Down Expand Up @@ -411,7 +407,7 @@ export async function validateCredentials(
clientConfig?: CodyClientConfig
): Promise<AuthStatus> {
// An access token is needed except for Cody Web, which uses cookies.
if (!config.auth.accessTokenOrHeaders && !clientCapabilities().isCodyWeb) {
if (!config.auth.credentials && !clientCapabilities().isCodyWeb) {
return { authenticated: false, endpoint: config.auth.serverEndpoint, pendingValidation: false }
}

Expand Down
4 changes: 2 additions & 2 deletions vscode/src/auth/token-receiver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const FIVE_MINUTES = 5 * 60 * 1000
// the user follow a redirect.
export function startTokenReceiver(
endpoint: string,
onNewToken: (credentials: Pick<AuthCredentials, 'serverEndpoint' | 'accessTokenOrHeaders'>) => void,
onNewToken: (credentials: Pick<AuthCredentials, 'serverEndpoint' | 'credentials'>) => void,
timeout = FIVE_MINUTES
): Promise<string> {
const endpointUrl = new URL(endpoint)
Expand Down Expand Up @@ -48,7 +48,7 @@ export function startTokenReceiver(
) {
onNewToken({
serverEndpoint: endpoint,
accessTokenOrHeaders: json.accessToken,
credentials: { token: json.accessToken, source: 'redirect' },
})

res.writeHead(200, headers)
Expand Down
2 changes: 1 addition & 1 deletion vscode/src/autoedits/adapters/cody-gateway.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ describe('CodyGatewayAdapter', () => {
mockResolvedConfig({
configuration: {},
auth: {
accessTokenOrHeaders: 'sgp_local_f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0',
credentials: { token: 'sgp_local_f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0' },
serverEndpoint: DOTCOM_URL.toString(),
},
})
Expand Down
4 changes: 2 additions & 2 deletions vscode/src/autoedits/adapters/cody-gateway.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ export class CodyGatewayAdapter implements AutoeditsModelAdapter {
const resolvedConfig = await currentResolvedConfig()
// TODO (pkukielka): Check if fastpath should support custom auth providers and how
const accessToken =
typeof resolvedConfig.auth.accessTokenOrHeaders === 'string'
? resolvedConfig.auth.accessTokenOrHeaders
resolvedConfig.auth.credentials && 'token' in resolvedConfig.auth.credentials
? resolvedConfig.auth.credentials.token
: null
const fastPathAccessToken = dotcomTokenToGatewayToken(accessToken)
if (!fastPathAccessToken) {
Expand Down
2 changes: 1 addition & 1 deletion vscode/src/autoedits/autoedits-provider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ describe('AutoeditsProvider', () => {
mockResolvedConfig({
configuration: {},
auth: {
accessTokenOrHeaders: 'sgp_local_f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0',
credentials: { token: 'sgp_local_f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0' },
serverEndpoint: DOTCOM_URL.toString(),
},
})
Expand Down
Loading

0 comments on commit a0ad157

Please sign in to comment.