Skip to content

Commit

Permalink
Refactor NettyTransport#getLocalAddress()
Browse files Browse the repository at this point in the history
  • Loading branch information
Jochen Schalanda committed Jan 3, 2018
1 parent a16554b commit 6170419
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
import io.netty.channel.EventLoopGroup;
import io.netty.channel.FixedRecvByteBufAllocator;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.socket.DatagramChannelConfig;
import io.netty.channel.unix.UnixChannelOption;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.graylog2.inputs.transports.netty.DatagramChannelFactory;
import org.graylog2.inputs.transports.netty.DatagramPacketHandler;
import org.graylog2.inputs.transports.netty.NettyTransportType;
Expand All @@ -48,6 +50,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import java.net.SocketAddress;
import java.util.LinkedHashMap;
import java.util.concurrent.Callable;
Expand All @@ -56,7 +59,7 @@ public class UdpTransport extends NettyTransport {
private static final Logger LOG = LoggerFactory.getLogger(UdpTransport.class);

private final NettyTransportConfiguration nettyTransportConfiguration;

private final ChannelGroup channels;
private Bootstrap bootstrap;

@AssistedInject
Expand All @@ -67,6 +70,7 @@ public UdpTransport(@Assisted Configuration configuration,
LocalMetricRegistry localRegistry) {
super(configuration, eventLoopGroup, throughputCounter, localRegistry);
this.nettyTransportConfiguration = nettyTransportConfiguration;
this.channels = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
}

@VisibleForTesting
Expand Down Expand Up @@ -113,17 +117,17 @@ public void launch(final MessageInput input) throws MisfireException {
}
}


@Override
public void stop() {
super.stop();
if (channels != null) {
channels.close().syncUninterruptibly();
}
bootstrap = null;
}

/**
* Get the local socket address this transport is listening on after being launched.
*
* @return the listening address of this transport or {@code null} if the transport hasn't been launched yet.
*/
@Nullable
@Override
public SocketAddress getLocalAddress() {
if (channels != null) {
return channels.stream().findFirst().map(Channel::localAddress).orElse(null);
Expand All @@ -132,6 +136,7 @@ public SocketAddress getLocalAddress() {
return null;
}


@FactoryClass
public interface Factory extends Transport.Factory<UdpTransport> {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.FixedRecvByteBufAllocator;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.socket.ServerSocketChannelConfig;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContext;
Expand All @@ -55,10 +54,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import java.io.File;
import java.io.IOException;
import java.net.SocketAddress;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
Expand All @@ -71,6 +72,7 @@
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

import static com.google.common.base.Preconditions.checkState;

Expand Down Expand Up @@ -99,6 +101,7 @@ public abstract class AbstractTcpTransport extends NettyTransport {

protected final Configuration configuration;
private final NettyTransportConfiguration nettyTransportConfiguration;
private final AtomicReference<Channel> channelReference;

private final boolean tlsEnable;
private final String tlsKeyPassword;
Expand All @@ -119,6 +122,7 @@ public AbstractTcpTransport(
super(configuration, eventLoopGroup, throughputCounter, localRegistry);
this.configuration = configuration;
this.nettyTransportConfiguration = nettyTransportConfiguration;
this.channelReference = new AtomicReference<>();

this.tlsEnable = configuration.getBoolean(CK_TLS_ENABLE);
this.tlsCertFile = getTlsFile(configuration, CK_TLS_CERT_FILE);
Expand Down Expand Up @@ -172,13 +176,32 @@ public void launch(final MessageInput input) throws MisfireException {
bootstrap = getBootstrap(input);

bootstrap.bind(socketAddress)
.addListener(new InputLaunchListener(channels, input, getRecvBufferSize()))
.addListener(new InputLaunchListener(channelReference, input, getRecvBufferSize()))
.syncUninterruptibly();
} catch (Exception e) {
throw new MisfireException(e);
}
}

@Nullable
@Override
public SocketAddress getLocalAddress() {
final Channel channel = channelReference.get();
if (channel != null) {
return channel.localAddress();
}

return null;
}

@Override
public void stop() {
final Channel channel = channelReference.get();
if (channel != null) {
channel.close().syncUninterruptibly();
}
}

@Override
protected LinkedHashMap<String, Callable<? extends ChannelHandler>> getChildChannelHandlers(MessageInput input) {
final LinkedHashMap<String, Callable<? extends ChannelHandler>> handlers = new LinkedHashMap<>();
Expand Down Expand Up @@ -355,12 +378,12 @@ public ConfigurationRequest getRequestedConfiguration() {
}

private static class InputLaunchListener implements ChannelFutureListener {
private final ChannelGroup channels;
private final AtomicReference<Channel> channelReference;
private final MessageInput input;
private final int expectedRecvBufferSize;

public InputLaunchListener(ChannelGroup channels, MessageInput input, int expectedRecvBufferSize) {
this.channels = channels;
public InputLaunchListener(AtomicReference<Channel> channelReference, MessageInput input, int expectedRecvBufferSize) {
this.channelReference = channelReference;
this.input = input;
this.expectedRecvBufferSize = expectedRecvBufferSize;
}
Expand All @@ -369,7 +392,7 @@ public InputLaunchListener(ChannelGroup channels, MessageInput input, int expect
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
final Channel channel = future.channel();
channels.add(channel);
channelReference.set(channel);
LOG.debug("Started channel {}", channel);

final ServerSocketChannelConfig channelConfig = (ServerSocketChannelConfig) channel.config();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ public abstract class NettyTransport implements Transport {
protected final ThroughputCounter throughputCounter;
private final int recvBufferSize;

protected final ChannelGroup channels;

@Nullable
private CodecAggregator aggregator;

Expand All @@ -85,7 +83,6 @@ public NettyTransport(Configuration configuration,
: MessageInput.getDefaultRecvBufferSize();

this.eventLoopGroup = eventLoopGroup;
this.channels = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

this.localRegistry = localRegistry;
localRegistry.registerAll(MetricSets.of(throughputCounter.gauges()));
Expand All @@ -103,6 +100,15 @@ protected void initChannel(Channel ch) throws Exception {
};
}

/**
* Get the local socket address this transport is listening on after being launched.
*
* @return the listening address of this transport or {@code null} if the transport hasn't been launched yet.
*/
@VisibleForTesting
@Nullable
public abstract SocketAddress getLocalAddress();

@Override
public void setMessageAggregator(@Nullable CodecAggregator aggregator) {
this.aggregator = aggregator;
Expand Down Expand Up @@ -169,13 +175,6 @@ protected LinkedHashMap<String, Callable<? extends ChannelHandler>> getChildChan
return handlerList;
}

@Override
public void stop() {
if (channels != null) {
channels.close().syncUninterruptibly();
}
}

protected int getRecvBufferSize() {
return recvBufferSize;
}
Expand All @@ -185,21 +184,6 @@ public MetricSet getMetricSet() {
return localRegistry;
}

/**
* Get the local socket address this transport is listening on after being launched.
*
* @return the listening address of this transport or {@code null} if the transport hasn't been launched yet.
*/
@VisibleForTesting
@Nullable
SocketAddress getLocalAddress() {
if (channels != null) {
return channels.stream().findFirst().map(Channel::localAddress).orElse(null);
}

return null;
}

public static class Config implements Transport.Config {
@Override
public ConfigurationRequest getRequestedConfiguration() {
Expand Down

0 comments on commit 6170419

Please sign in to comment.