Skip to content

Commit

Permalink
Add tests for mutual auth
Browse files Browse the repository at this point in the history
  • Loading branch information
armstrongnate committed Jun 12, 2018
1 parent 2252397 commit 18af90f
Show file tree
Hide file tree
Showing 11 changed files with 104 additions and 37 deletions.
4 changes: 3 additions & 1 deletion Sources/CgRPC/shim/cgrpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ void cgrpc_channel_watch_connectivity_state(cgrpc_channel *channel,
cgrpc_server *cgrpc_server_create(const char *address);
cgrpc_server *cgrpc_server_create_secure(const char *address,
const char *private_key,
const char *cert_chain);
const char *cert_chain,
const char *root_certs,
int force_client_auth);
void cgrpc_server_stop(cgrpc_server *server);
void cgrpc_server_destroy(cgrpc_server *s);
void cgrpc_server_start(cgrpc_server *s);
Expand Down
8 changes: 5 additions & 3 deletions Sources/CgRPC/shim/server.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ cgrpc_server *cgrpc_server_create(const char *address) {

cgrpc_server *cgrpc_server_create_secure(const char *address,
const char *private_key,
const char *cert_chain) {
const char *cert_chain,
const char *root_certs,
int force_client_auth) {
cgrpc_server *server = (cgrpc_server *) malloc(sizeof (cgrpc_server));
server->server = grpc_server_create(NULL, NULL);
server->completion_queue = grpc_completion_queue_create_for_next(NULL);
Expand All @@ -44,10 +46,10 @@ cgrpc_server *cgrpc_server_create_secure(const char *address,
server_credentials.cert_chain = cert_chain;

grpc_server_credentials *credentials = grpc_ssl_server_credentials_create
(NULL,
(root_certs,
&server_credentials,
1,
0,
force_client_auth,
NULL);

// prepare the server to listen
Expand Down
8 changes: 4 additions & 4 deletions Sources/Examples/Echo/Generated/echo.grpc.swift
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,14 @@ internal final class Echo_EchoServer: ServiceServer {
super.init(address: address)
}

internal init?(address: String, certificateURL: URL, keyURL: URL, provider: Echo_EchoProvider) {
internal init?(address: String, certificateURL: URL, keyURL: URL, rootCertsURL: URL? = nil, provider: Echo_EchoProvider) {
self.provider = provider
super.init(address: address, certificateURL: certificateURL, keyURL: keyURL)
super.init(address: address, certificateURL: certificateURL, keyURL: keyURL, rootCertsURL: rootCertsURL)
}

internal init?(address: String, certificateString: String, keyString: String, provider: Echo_EchoProvider) {
internal init?(address: String, certificateString: String, keyString: String, rootCerts: String? = nil, provider: Echo_EchoProvider) {
self.provider = provider
super.init(address: address, certificateString: certificateString, keyString: keyString)
super.init(address: address, certificateString: certificateString, keyString: keyString, rootCerts: rootCerts)
}

/// Determines and calls the appropriate request handler, depending on the request's method.
Expand Down
4 changes: 2 additions & 2 deletions Sources/SwiftGRPC/Core/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ public class Server {
/// - Parameter address: the address where the server will listen
/// - Parameter key: the private key for the server's certificates
/// - Parameter certs: the server's certificates
public init(address: String, key: String, certs: String) {
underlyingServer = cgrpc_server_create_secure(address, key, certs)
public init(address: String, key: String, certs: String, rootCerts: String? = nil) {
underlyingServer = cgrpc_server_create_secure(address, key, certs, rootCerts, rootCerts == nil ? 0 : 1)
completionQueue = CompletionQueue(
underlyingCompletionQueue: cgrpc_server_get_completion_queue(underlyingServer), name: "Server " + address)
}
Expand Down
15 changes: 11 additions & 4 deletions Sources/SwiftGRPC/Runtime/ServiceServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,27 @@ open class ServiceServer {
}

/// Create a server that accepts secure connections.
public init(address: String, certificateString: String, keyString: String) {
public init(address: String, certificateString: String, keyString: String, rootCerts: String? = nil) {
gRPC.initialize()
self.address = address
server = Server(address: address, key: keyString, certs: certificateString)
server = Server(address: address, key: keyString, certs: certificateString, rootCerts: rootCerts)
}

/// Create a server that accepts secure connections.
public init?(address: String, certificateURL: URL, keyURL: URL) {
public init?(address: String, certificateURL: URL, keyURL: URL, rootCertsURL: URL?) {
guard let certificate = try? String(contentsOf: certificateURL, encoding: .utf8),
let key = try? String(contentsOf: keyURL, encoding: .utf8)
else { return nil }
var rootCerts: String?
if let rootCertsURL = rootCertsURL {
guard let rootCertsString = try? String(contentsOf: rootCertsURL, encoding: .utf8) else {
return nil
}
rootCerts = rootCertsString
}
gRPC.initialize()
self.address = address
server = Server(address: address, key: key, certs: certificate)
server = Server(address: address, key: key, certs: certificate, rootCerts: rootCerts)
}

public enum HandleMethodError: Error {
Expand Down
8 changes: 4 additions & 4 deletions Sources/protoc-gen-swiftgrpc/Generator-Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,17 @@ extension Generator {
outdent()
println("}")
println()
println("\(access) init?(address: String, certificateURL: URL, keyURL: URL, provider: \(providerName)) {")
println("\(access) init?(address: String, certificateURL: URL, keyURL: URL, rootCertsURL: URL? = nil, provider: \(providerName)) {")
indent()
println("self.provider = provider")
println("super.init(address: address, certificateURL: certificateURL, keyURL: keyURL)")
println("super.init(address: address, certificateURL: certificateURL, keyURL: keyURL, rootCertsURL: rootCertsURL)")
outdent()
println("}")
println()
println("\(access) init?(address: String, certificateString: String, keyString: String, provider: \(providerName)) {")
println("\(access) init?(address: String, certificateString: String, keyString: String, rootCerts: String? = nil, provider: \(providerName)) {")
indent()
println("self.provider = provider")
println("super.init(address: address, certificateString: certificateString, keyString: keyString)")
println("super.init(address: address, certificateString: certificateString, keyString: keyString, rootCerts: rootCerts)")
outdent()
println("}")
println()
Expand Down
33 changes: 25 additions & 8 deletions Tests/SwiftGRPCTests/BasicEchoTestCase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,36 +31,53 @@ extension Echo_EchoResponse {
}

class BasicEchoTestCase: XCTestCase {
enum Security {
case none
case ssl
case tlsMutualAuth
}

func makeProvider() -> Echo_EchoProvider { return EchoProvider() }

var provider: Echo_EchoProvider!
var server: Echo_EchoServer!
var client: Echo_EchoServiceClient!

var defaultTimeout: TimeInterval { return 1.0 }
var secure: Bool { return false }
var security: Security { return .none }
var address: String { return "localhost:5050" }

override func setUp() {
super.setUp()

provider = makeProvider()

if secure {
let certificateString = String(data: certificateForTests, encoding: .utf8)!

let certificateString = String(data: certificateForTests, encoding: .utf8)!
let keyString = String(data: keyForTests, encoding: .utf8)!
let rootCerts = String(data: trustCollectionCertificateForTests, encoding: .utf8)!
let clientCertificateString = String(data: clientCertificateForTests, encoding: .utf8)!
let clientKeyString = String(data: clientKeyForTests, encoding: .utf8)!

switch security {
case .ssl:
server = Echo_EchoServer(address: address,
certificateString: certificateString,
keyString: String(data: keyForTests, encoding: .utf8)!,
keyString: keyString,
provider: provider)
server.start()
client = Echo_EchoServiceClient(address: address, certificates: certificateString, arguments: [.sslTargetNameOverride("example.com")])
client = Echo_EchoServiceClient(address: address, certificates: rootCerts, arguments: [.sslTargetNameOverride("example.com")])
client.host = "example.com"
} else {
case .tlsMutualAuth:
server = Echo_EchoServer(address: address, certificateString: certificateString, keyString: keyString, rootCerts: rootCerts, provider: provider)
server.start()
client = Echo_EchoServiceClient(address: address, certificates: rootCerts, clientCertificates: clientCertificateString, clientKey: clientKeyString, arguments: [.sslTargetNameOverride("example.com")])
client.host = "example.com"
case .none:
server = Echo_EchoServer(address: address, provider: provider)
server.start()
client = Echo_EchoServiceClient(address: address, secure: false)
}

client.timeout = defaultTimeout
}

Expand Down
6 changes: 5 additions & 1 deletion Tests/SwiftGRPCTests/EchoTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ class EchoTests: BasicEchoTestCase {
}

class EchoTestsSecure: EchoTests {
override var secure: Bool { return true }
override var security: Security { return .ssl }
}

class EchoTestsMutualAuth: EchoTests {
override var security: Security { return .tlsMutualAuth }
}

extension EchoTests {
Expand Down
2 changes: 1 addition & 1 deletion Tests/SwiftGRPCTests/GRPCTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func runClient(useSSL: Bool) throws {

if useSSL {
channel = Channel(address: address,
certificates: String(data: certificateForTests, encoding: .utf8)!,
certificates: String(data: trustCollectionCertificateForTests, encoding: .utf8)!,
arguments: [.sslTargetNameOverride(host)])
} else {
channel = Channel(address: address, secure: false)
Expand Down
Loading

0 comments on commit 18af90f

Please sign in to comment.