Skip to content

Commit

Permalink
feat(repo): add geth-rpc-gateway (#18382)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaodino authored Nov 25, 2024
1 parent 61994ff commit d998291
Show file tree
Hide file tree
Showing 3 changed files with 388 additions and 0 deletions.
1 change: 1 addition & 0 deletions packages/geth-rpc-gateway/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
geth-rpc-gateway
73 changes: 73 additions & 0 deletions packages/geth-rpc-gateway/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# geth-rpc-gateway

```sh
go build -o geth-rpc-gateway .
```

Build for Linux

```sh
GOOS=linux GOARCH=amd64 go build -o geth-rpc-gateway .
```

## How to test

### Example code

```
curl --location --request POST 'https://rpc.internal.taiko.xyz/' \
--header 'Content-Type: application/json' \
--data-raw '{
"jsonrpc": "2.0",
"id": 4,
"method": "eth_blockNumber",
"params": [
]
}'
```

```
'use strict'
const { ethers } = require('ethers');
// const provider = new ethers.providers.JsonRpcProvider("https://l1rpc.mainnet.taiko.xyz");
const provider = new ethers.providers.WebSocketProvider("wss://ws.internal.taiko.xyz");
async function main() {
console.log(await provider.getBlock("latest"));
process.exit(0);
}
main().catch(console.error);
```

```
curl -i -X POST \
-H "Content-Type:application/json" \
-d \
'[
{"id":92471,"jsonrpc":"2.0","method":"eth_getTransactionReceipt","params":["0x832ef3260c46288e9596d0ddb61c4c9d5965f7da8d076483d08ac2d4265a69b8"]},
{"id":91112,"jsonrpc":"2.0","method":"eth_getTransactionReceipt","params":["0xbaac413b4cbf6a2f19ef3da2f103f8298042cbba2820fba020a322f9602f8e58"]},
{"id":48734,"jsonrpc":"2.0","method":"eth_getTransactionReceipt","params":["0x7c649da4df9bea4552c05d4710a1ffb16fed5be81c11912aceb568a8212213d6"]},
{"id":45180,"jsonrpc":"2.0","method":"eth_getTransactionReceipt","params":["0xb23f58cb6b5155f792fa96c63962c44efba5280a4eed76400eca477e04c7456c"]},
{"id":95408,"jsonrpc":"2.0","method":"eth_getTransactionReceipt","params":["0xab7e06e9666ba0c270fe06e45fe604316049232c4479f975db0a0ec16b4f9b38"]},
{"id":193,"jsonrpc":"2.0","method":"eth_getTransactionReceipt","params":["0xd453488f5e14cfb3ac1057e42c1e3eb74420759fe0331894c59f3108e1c813b0"]}
]' \
'https://rpc.hekla.taiko.xyz/'
```

```
curl -i -X POST \
-H "Content-Type:application/json" \
-d \
'[
{"id":92471,"jsonrpc":"2.0","method":"eth_getTransactionReceipt","params":["0x832ef3260c46288e9596d0ddb61c4c9d5965f7da8d076483d08ac2d4265a69b8"]},
{"id":91112,"jsonrpc":"2.0","method":"eth_getTransactionReceipt","params":["0xbaac413b4cbf6a2f19ef3da2f103f8298042cbba2820fba020a322f9602f8e58"]},
{"id":48734,"jsonrpc":"2.0","method":"eth_getTransactionReceipt","params":["0x7c649da4df9bea4552c05d4710a1ffb16fed5be81c11912aceb568a8212213d6"]},
{"id":45180,"jsonrpc":"2.0","method":"eth_getTransactionReceipt","params":["0xb23f58cb6b5155f792fa96c63962c44efba5280a4eed76400eca477e04c7456c"]},
{"id":95408,"jsonrpc":"2.0","method":"eth_getTransactionReceipt","params":["0xab7e06e9666ba0c270fe06e45fe604316049232c4479f975db0a0ec16b4f9b38"]},
{"id":193,"jsonrpc":"2.0","method":"eth_getTransactionReceipt","params":["0xd453488f5e14cfb3ac1057e42c1e3eb74420759fe0331894c59f3108e1c813b0"]}
]' \
'http://localhost:8080'
```
314 changes: 314 additions & 0 deletions packages/geth-rpc-gateway/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,314 @@
package main

import (
"bytes"
"encoding/json"
"io"
"io/ioutil"
"log"
"net/http"
"net/url"
"os"
"strings"

"github.com/gorilla/websocket"
)

type JSONRPCRequest struct {
Method string `json:"method"`
}

var (
upgrader = websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
methodsUsingPrimary map[string]bool
primaryURL *url.URL
secondaryURL *url.URL
webSocketURL *url.URL
enableDebugEndpoints bool
)

func main() {
// Load the target URLs from environment variables
var err error
primaryURL, err = url.Parse(os.Getenv("TARGET_URL_PRIMARY"))
if err != nil || primaryURL == nil {
log.Fatalf("Failed to parse primary target URL: %v", err)
}
secondaryURL, err = url.Parse(os.Getenv("TARGET_URL_SECONDARY"))
if err != nil || secondaryURL == nil {
log.Fatalf("Failed to parse secondary target URL: %v", err)
}
webSocketURL, err = url.Parse(os.Getenv("WEBSOCKET_TARGET_URL"))
if err != nil || webSocketURL == nil {
log.Fatalf("Failed to parse WebSocket target URL: %v", err)
}

methodsUsingPrimary = parsePrimaryMethods(os.Getenv("PRIMARY_METHODS"))
enableDebugEndpoints = os.Getenv("ENABLE_DEBUG_ENDPOINTS") == "true"

http.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
log.Printf("/healthz Received request: Method=%s, Path=%s", r.Method, r.URL.Path)
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})

// Determine if server should handle WebSocket or RPC based on environment variable
if os.Getenv("IS_WEBSOCKET") == "true" {
log.Println("Starting in WebSocket mode")
http.HandleFunc("/", rootWebSocketHandler) // WebSocket handler without CORS
} else {
log.Println("Starting in RPC mode")
http.Handle("/", enableCORS(http.HandlerFunc(rootHandler))) // HTTP handler with CORS middleware
}

log.Fatal(http.ListenAndServe(":8080", nil))
}

// WebSocket handler for `/` path when in WebSocket mode
func rootWebSocketHandler(w http.ResponseWriter, r *http.Request) {
// Check for WebSocket Upgrade
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
log.Printf("WebSocket connection initiated...")
handleWebSocket(w, r, webSocketURL)
return
}

w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}

// CORS middleware to enable CORS headers
func enableCORS(next http.Handler) http.Handler {
log.Printf("enableCORS")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("CORS middleware invoked for %s %s", r.Method, r.URL.Path)

// Get the Origin header from the request
origin := r.Header.Get("Origin")

// Set Access-Control-Allow-Origin only if the request has an Origin header
if origin != "" {
log.Printf("CORS middleware invoked for origin %s", origin)
w.Header().Del("Access-Control-Allow-Origin") // Clear any existing header
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Vary", "Origin") // Ensure caching based on origin
}

w.Header().Set("Access-Control-Allow-Methods", r.Method)
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

w.WriteHeader(http.StatusOK)

next.ServeHTTP(w, r)
})
}

func rootHandler(w http.ResponseWriter, r *http.Request) {
log.Printf("rootHandler...")

// Check for WebSocket Upgrade
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
handleWebSocket(w, r, webSocketURL)
return
}

// Handle HTTP requests

bodyBytes, err := ioutil.ReadAll(r.Body)

log.Printf("Handle HTTP requests...")
if err != nil {
log.Printf("Error")

http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
defer r.Body.Close()

if len(bodyBytes) == 0 {
w.Write([]byte("OK"))
return
}

// Determine the target URL and extract methods
usePrimaryURL, methods := shouldUsePrimaryURL(bodyBytes)
var targetURL *url.URL
if usePrimaryURL {
targetURL = primaryURL
log.Printf("HTTP request hitting TARGET_URL_PRIMARY")
} else {
targetURL = secondaryURL
log.Printf("HTTP request hitting TARGET_URL_SECONDARY")
}

// Check each method for debug restrictions
for _, method := range methods {
if enableDebugEndpoints && isDebugMethod(method) && method != "debug_traceBlock" && method != "debug_traceBlockByNumber" {
http.Error(w, "Unsupported method", http.StatusBadRequest)
return
}
}

// Forward the original JSON payload as-is to the target URL
forwardRequest(w, r, targetURL, bodyBytes)
}

// Function to forward the request to the target URL
func forwardRequest(w http.ResponseWriter, r *http.Request, targetURL *url.URL, bodyBytes []byte) {

proxyReq, err := http.NewRequest(r.Method, targetURL.String()+r.RequestURI, ioutil.NopCloser(bytes.NewReader(bodyBytes)))
if err != nil {
http.Error(w, "Failed to create request", http.StatusInternalServerError)
return
}

// Copy headers from the original request, excluding Accept-Encoding
for name, values := range r.Header {
log.Printf("proxy req name %s, value %s", name, values)
if name == "Accept-Encoding" {
continue
}
for _, value := range values {
proxyReq.Header.Add(name, value)
}
}

// Send the request to the target URL
resp, err := http.DefaultClient.Do(proxyReq)
if err != nil {
http.Error(w, "Failed to reach target server", http.StatusInternalServerError)
return
}
defer resp.Body.Close()

// Log headers before setting them to diagnose any discrepancies
log.Printf("Received Content-Type from upstream: %s", resp.Header.Get("Content-Type"))

// Prepare to copy headers from the response
for name, values := range resp.Header {
log.Printf("response name %s, value %s", name, values)
switch name {
case "Content-Length", "Transfer-Encoding", "Connection":
// Skip these headers
continue
default:
for _, value := range values {
w.Header().Add(name, value)
}
}
}

// Explicitly set Content-Type if it's present in the response
if contentType := resp.Header.Get("Content-Type"); contentType != "" {
w.Header().Set("Content-Type", contentType)
} else {
w.Header().Set("Content-Type", "application/json") // default if not provided
}
log.Printf("Set Content-Type header: %s", w.Header().Get("Content-Type"))

// Read the response body into a buffer to set Content-Type explicitly
var buf bytes.Buffer
if _, err := io.Copy(&buf, resp.Body); err != nil {
log.Printf("Error reading response body into buffer: %v", err)
http.Error(w, "Failed to read response body", http.StatusInternalServerError)
return
}

// Write status code and ensure Content-Type is set
w.WriteHeader(resp.StatusCode)
log.Printf("Response status code: %d", resp.StatusCode)
log.Printf("Final Content-Type header: %s", w.Header().Get("Content-Type"))

// Write the buffered body to the response
if _, err := io.Copy(w, &buf); err != nil {
log.Printf("Error copying buffer to response: %v", err)
}
}

func isDebugMethod(method string) bool {
return len(method) >= 6 && method[:6] == "debug_" && method != "debug_traceBlock"
}

// Parses the PRIMARY_METHODS environment variable and returns a map of methods using the primary URL
func parsePrimaryMethods(methods string) map[string]bool {
methodMap := make(map[string]bool)
for _, method := range strings.Split(methods, ",") {
method = strings.TrimSpace(method)
if method != "" {
methodMap[method] = true
}
}
return methodMap
}

// Checks if any method should use the primary URL and returns all methods
func shouldUsePrimaryURL(bodyBytes []byte) (bool, []string) {
var singleRequest JSONRPCRequest
var multipleRequests []JSONRPCRequest
methods := []string{}

// Try unmarshalling as a single request
if err := json.Unmarshal(bodyBytes, &singleRequest); err == nil {
methods = append(methods, singleRequest.Method)
return methodsUsingPrimary[singleRequest.Method], methods
}

// Try unmarshalling as an array of requests
if err := json.Unmarshal(bodyBytes, &multipleRequests); err == nil {
usePrimary := false
for _, req := range multipleRequests {
methods = append(methods, req.Method)
if methodsUsingPrimary[req.Method] {
usePrimary = true
}
}
return usePrimary, methods
}

log.Printf("Invalid JSON in request body: unable to parse as single or multiple requests")
return false, methods // Default to secondary URL if JSON is invalid
}

func handleWebSocket(w http.ResponseWriter, r *http.Request, targetURL *url.URL) {
clientConn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("Failed to upgrade connection: %v", err)
http.Error(w, "Failed to upgrade to WebSocket", http.StatusInternalServerError)
return
}
defer clientConn.Close()

targetConn, _, err := websocket.DefaultDialer.Dial(targetURL.String(), nil)
if err != nil {
log.Printf("Failed to connect to target WebSocket server: %v", err)
http.Error(w, "Failed to connect to target WebSocket server", http.StatusInternalServerError)
return
}
defer targetConn.Close()

go func() {
for {
messageType, message, err := clientConn.ReadMessage()
if err != nil {
log.Printf("Error reading message from client: %v", err)
return
}
if err := targetConn.WriteMessage(messageType, message); err != nil {
log.Printf("Error writing message to target server: %v", err)
return
}
}
}()

for {
messageType, message, err := targetConn.ReadMessage()
if err != nil {
log.Printf("Error reading message from target server: %v", err)
return
}
if err := clientConn.WriteMessage(messageType, message); err != nil {
log.Printf("Error writing message to client: %v", err)
return
}
}
}

0 comments on commit d998291

Please sign in to comment.