diff --git a/common/proxy/proxy.go b/common/proxy/proxy.go index 1c079f6bca..d13646dba8 100644 --- a/common/proxy/proxy.go +++ b/common/proxy/proxy.go @@ -181,3 +181,7 @@ func (p *Proxy) Implement(v common.RPCService) { func (p *Proxy) Get() common.RPCService { return p.rpc } + +func (p *Proxy) GetCallback() interface{} { + return p.callBack +} diff --git a/common/proxy/proxy_factory.go b/common/proxy/proxy_factory.go index 2567e0ee09..116cfe0669 100644 --- a/common/proxy/proxy_factory.go +++ b/common/proxy/proxy_factory.go @@ -24,6 +24,7 @@ import ( type ProxyFactory interface { GetProxy(invoker protocol.Invoker, url *common.URL) *Proxy + GetAsyncProxy(invoker protocol.Invoker, callBack interface{}, url *common.URL) *Proxy GetInvoker(url common.URL) protocol.Invoker } diff --git a/common/proxy/proxy_factory/default.go b/common/proxy/proxy_factory/default.go index bafba60b40..06824fdc1e 100644 --- a/common/proxy/proxy_factory/default.go +++ b/common/proxy/proxy_factory/default.go @@ -55,11 +55,16 @@ func NewDefaultProxyFactory(options ...proxy.Option) proxy.ProxyFactory { return &DefaultProxyFactory{} } func (factory *DefaultProxyFactory) GetProxy(invoker protocol.Invoker, url *common.URL) *proxy.Proxy { + return factory.GetAsyncProxy(invoker, nil, url) +} + +func (factory *DefaultProxyFactory) GetAsyncProxy(invoker protocol.Invoker, callBack interface{}, url *common.URL) *proxy.Proxy { //create proxy attachments := map[string]string{} attachments[constant.ASYNC_KEY] = url.GetParam(constant.ASYNC_KEY, "false") - return proxy.NewProxy(invoker, nil, attachments) + return proxy.NewProxy(invoker, callBack, attachments) } + func (factory *DefaultProxyFactory) GetInvoker(url common.URL) protocol.Invoker { return &ProxyInvoker{ BaseInvoker: *protocol.NewBaseInvoker(url), diff --git a/common/proxy/proxy_factory/default_test.go b/common/proxy/proxy_factory/default_test.go index b6a6b675ba..7159b4b00e 100644 --- a/common/proxy/proxy_factory/default_test.go +++ b/common/proxy/proxy_factory/default_test.go @@ -18,6 +18,7 @@ package proxy_factory import ( + "fmt" "testing" ) @@ -37,6 +38,21 @@ func Test_GetProxy(t *testing.T) { assert.NotNil(t, proxy) } +type TestAsync struct { +} + +func (u *TestAsync) CallBack(res common.CallbackResponse) { + fmt.Println("CallBack res:", res) +} + +func Test_GetAsyncProxy(t *testing.T) { + proxyFactory := NewDefaultProxyFactory() + url := common.NewURLWithOptions() + async := &TestAsync{} + proxy := proxyFactory.GetAsyncProxy(protocol.NewBaseInvoker(*url), async.CallBack, url) + assert.NotNil(t, proxy) +} + func Test_GetInvoker(t *testing.T) { proxyFactory := NewDefaultProxyFactory() url := common.NewURLWithOptions() diff --git a/common/rpc_service.go b/common/rpc_service.go index 4741a6fa3c..4c9f083dd0 100644 --- a/common/rpc_service.go +++ b/common/rpc_service.go @@ -39,6 +39,18 @@ type RPCService interface { Reference() string // rpc service id or reference id } +//AsyncCallbackService callback interface for async +type AsyncCallbackService interface { + CallBack(response CallbackResponse) // callback +} + +//CallbackResponse for different protocol +type CallbackResponse interface { +} + +//AsyncCallback async callback method +type AsyncCallback func(response CallbackResponse) + // for lowercase func // func MethodMapper() map[string][string] { // return map[string][string]{} diff --git a/config/reference_config.go b/config/reference_config.go index 8703c459ba..6b34f55359 100644 --- a/config/reference_config.go +++ b/config/reference_config.go @@ -55,7 +55,7 @@ type ReferenceConfig struct { Group string `yaml:"group" json:"group,omitempty" property:"group"` Version string `yaml:"version" json:"version,omitempty" property:"version"` Methods []*MethodConfig `yaml:"methods" json:"methods,omitempty" property:"methods"` - async bool `yaml:"async" json:"async,omitempty" property:"async"` + Async bool `yaml:"async" json:"async,omitempty" property:"async"` Params map[string]string `yaml:"params" json:"params,omitempty" property:"params"` invoker protocol.Invoker urls []*common.URL @@ -141,7 +141,12 @@ func (refconfig *ReferenceConfig) Refer() { } //create proxy - refconfig.pxy = extension.GetProxyFactory(consumerConfig.ProxyFactory).GetProxy(refconfig.invoker, url) + if refconfig.Async { + callback := GetCallback(refconfig.id) + refconfig.pxy = extension.GetProxyFactory(consumerConfig.ProxyFactory).GetAsyncProxy(refconfig.invoker, callback, url) + } else { + refconfig.pxy = extension.GetProxyFactory(consumerConfig.ProxyFactory).GetProxy(refconfig.invoker, url) + } } // @v is service provider implemented RPCService @@ -169,7 +174,7 @@ func (refconfig *ReferenceConfig) getUrlMap() url.Values { urlMap.Set(constant.GENERIC_KEY, strconv.FormatBool(refconfig.Generic)) urlMap.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER)) //getty invoke async or sync - urlMap.Set(constant.ASYNC_KEY, strconv.FormatBool(refconfig.async)) + urlMap.Set(constant.ASYNC_KEY, strconv.FormatBool(refconfig.Async)) //application info urlMap.Set(constant.APPLICATION_KEY, consumerConfig.ApplicationConfig.Name) diff --git a/config/reference_config_test.go b/config/reference_config_test.go index a81dbf06ce..a7af925cab 100644 --- a/config/reference_config_test.go +++ b/config/reference_config_test.go @@ -81,6 +81,7 @@ func doInitConsumer() { }, References: map[string]*ReferenceConfig{ "MockService": { + id: "MockProvider", Params: map[string]string{ "serviceid": "soa.mock", "forks": "5", @@ -110,6 +111,26 @@ func doInitConsumer() { } } +var mockProvider = new(MockProvider) + +type MockProvider struct { +} + +func (m *MockProvider) Reference() string { + return "MockProvider" +} + +func (m *MockProvider) CallBack(res common.CallbackResponse) { +} + +func doInitConsumerAsync() { + doInitConsumer() + SetConsumerService(mockProvider) + for _, v := range consumerConfig.References { + v.Async = true + } +} + func doInitConsumerWithSingleRegistry() { consumerConfig = &ConsumerConfig{ ApplicationConfig: &ApplicationConfig{ @@ -181,6 +202,22 @@ func Test_Refer(t *testing.T) { } consumerConfig = nil } + +func Test_ReferAsync(t *testing.T) { + doInitConsumerAsync() + extension.SetProtocol("registry", GetProtocol) + extension.SetCluster("registryAware", cluster_impl.NewRegistryAwareCluster) + + for _, reference := range consumerConfig.References { + reference.Refer() + assert.Equal(t, "soa.mock", reference.Params["serviceid"]) + assert.NotNil(t, reference.invoker) + assert.NotNil(t, reference.pxy) + assert.NotNil(t, reference.pxy.GetCallback()) + } + consumerConfig = nil +} + func Test_ReferP2P(t *testing.T) { doInitConsumer() extension.SetProtocol("dubbo", GetProtocol) diff --git a/config/service.go b/config/service.go index 2bceac4a8c..f1b51790ca 100644 --- a/config/service.go +++ b/config/service.go @@ -43,3 +43,11 @@ func GetConsumerService(name string) common.RPCService { func GetProviderService(name string) common.RPCService { return proServices[name] } + +func GetCallback(name string) func(response common.CallbackResponse) { + service := GetConsumerService(name) + if sv, ok := service.(common.AsyncCallbackService); ok { + return sv.CallBack + } + return nil +} diff --git a/protocol/dubbo/client.go b/protocol/dubbo/client.go index ba74d86c0c..1365838f3b 100644 --- a/protocol/dubbo/client.go +++ b/protocol/dubbo/client.go @@ -113,7 +113,9 @@ type Options struct { RequestTimeout time.Duration } -type CallResponse struct { +//AsyncCallbackResponse async response for dubbo +type AsyncCallbackResponse struct { + common.CallbackResponse Opts Options Cause error Start time.Time // invoke(call) start time == write start time @@ -121,8 +123,6 @@ type CallResponse struct { Reply interface{} } -type AsyncCallback func(response CallResponse) - type Client struct { opts Options conf ClientConfig @@ -199,12 +199,12 @@ func (c *Client) Call(request *Request, response *Response) error { return perrors.WithStack(c.call(ct, request, response, nil)) } -func (c *Client) AsyncCall(request *Request, callback AsyncCallback, response *Response) error { +func (c *Client) AsyncCall(request *Request, callback common.AsyncCallback, response *Response) error { return perrors.WithStack(c.call(CT_TwoWay, request, response, callback)) } -func (c *Client) call(ct CallType, request *Request, response *Response, callback AsyncCallback) error { +func (c *Client) call(ct CallType, request *Request, response *Response, callback common.AsyncCallback) error { p := &DubboPackage{} p.Service.Path = strings.TrimPrefix(request.svcUrl.Path, "/") diff --git a/protocol/dubbo/client_test.go b/protocol/dubbo/client_test.go index eb1f15c862..3f8a8ee98c 100644 --- a/protocol/dubbo/client_test.go +++ b/protocol/dubbo/client_test.go @@ -144,8 +144,9 @@ func TestClient_AsyncCall(t *testing.T) { user := &User{} lock := sync.Mutex{} lock.Lock() - err := c.AsyncCall(NewRequest("127.0.0.1:20000", url, "GetUser", []interface{}{"1", "username"}, nil), func(response CallResponse) { - assert.Equal(t, User{Id: "1", Name: "username"}, *response.Reply.(*Response).reply.(*User)) + err := c.AsyncCall(NewRequest("127.0.0.1:20000", url, "GetUser", []interface{}{"1", "username"}, nil), func(response common.CallbackResponse) { + r := response.(AsyncCallbackResponse) + assert.Equal(t, User{Id: "1", Name: "username"}, *r.Reply.(*Response).reply.(*User)) lock.Unlock() }, NewResponse(user, nil)) assert.NoError(t, err) diff --git a/protocol/dubbo/codec.go b/protocol/dubbo/codec.go index a878ffd91e..758363117f 100644 --- a/protocol/dubbo/codec.go +++ b/protocol/dubbo/codec.go @@ -26,6 +26,7 @@ import ( import ( "github.com/apache/dubbo-go-hessian2" + "github.com/apache/dubbo-go/common" perrors "github.com/pkg/errors" ) @@ -109,7 +110,7 @@ type PendingResponse struct { err error start time.Time readStart time.Time - callback AsyncCallback + callback common.AsyncCallback response *Response done chan struct{} } @@ -122,8 +123,8 @@ func NewPendingResponse() *PendingResponse { } } -func (r PendingResponse) GetCallResponse() CallResponse { - return CallResponse{ +func (r PendingResponse) GetCallResponse() common.CallbackResponse { + return AsyncCallbackResponse{ Cause: r.err, Start: r.start, ReadStart: r.readStart, diff --git a/protocol/dubbo/dubbo_invoker.go b/protocol/dubbo/dubbo_invoker.go index bc321a97a4..da12126103 100644 --- a/protocol/dubbo/dubbo_invoker.go +++ b/protocol/dubbo/dubbo_invoker.go @@ -75,7 +75,7 @@ func (di *DubboInvoker) Invoke(invocation protocol.Invocation) protocol.Result { } response := NewResponse(inv.Reply(), nil) if async { - if callBack, ok := inv.CallBack().(func(response CallResponse)); ok { + if callBack, ok := inv.CallBack().(func(response common.CallbackResponse)); ok { result.Err = di.client.AsyncCall(NewRequest(url.Location, url, inv.MethodName(), inv.Arguments(), inv.Attachments()), callBack, response) } else { result.Err = di.client.CallOneway(NewRequest(url.Location, url, inv.MethodName(), inv.Arguments(), inv.Attachments())) diff --git a/protocol/dubbo/dubbo_invoker_test.go b/protocol/dubbo/dubbo_invoker_test.go index 0a765356f7..7d60090e2d 100644 --- a/protocol/dubbo/dubbo_invoker_test.go +++ b/protocol/dubbo/dubbo_invoker_test.go @@ -28,6 +28,7 @@ import ( ) import ( + "github.com/apache/dubbo-go/common" "github.com/apache/dubbo-go/common/constant" "github.com/apache/dubbo-go/protocol/invocation" ) @@ -65,8 +66,9 @@ func TestDubboInvoker_Invoke(t *testing.T) { // AsyncCall lock := sync.Mutex{} lock.Lock() - inv.SetCallBack(func(response CallResponse) { - assert.Equal(t, User{Id: "1", Name: "username"}, *response.Reply.(*Response).reply.(*User)) + inv.SetCallBack(func(response common.CallbackResponse) { + r := response.(AsyncCallbackResponse) + assert.Equal(t, User{Id: "1", Name: "username"}, *r.Reply.(*Response).reply.(*User)) lock.Unlock() }) res = invoker.Invoke(inv)