Skip to content

Commit

Permalink
Fix race condition in DefaultBroadcasterFactory
Browse files Browse the repository at this point in the history
Fix race condition in DefaultBroadcasterFactory
  • Loading branch information
Jason Burgess committed Dec 5, 2011
1 parent 4dc491b commit 592ed26
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public abstract class BroadcasterFactory {
*
* @param b a {@link Broadcaster}
* @return false if wasn't present, or {@link Broadcaster}
* @oaram id the {@link Broadcaster's ID}
* @param id the {@link Broadcaster's ID}
*/
abstract public boolean remove(Broadcaster b, Object id);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,9 @@


import org.atmosphere.di.InjectorProvider;
import org.atmosphere.util.AbstractBroadcasterProxy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.InvocationTargetException;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
Expand All @@ -70,9 +68,9 @@ public class DefaultBroadcasterFactory extends BroadcasterFactory {
private static final Logger logger = LoggerFactory.getLogger(DefaultBroadcasterFactory.class);

private final ConcurrentHashMap<Object, Broadcaster> store = new ConcurrentHashMap<Object, Broadcaster>();

private final Class<? extends Broadcaster> clazz;

private BroadcasterLifeCyclePolicy policy =
new BroadcasterLifeCyclePolicy.Builder().policy(NEVER).build();

Expand Down Expand Up @@ -129,46 +127,33 @@ public final Broadcaster get(Object id) {
*/
public final Broadcaster get(Class<? extends Broadcaster> c, Object id) {

if (id == null) throw new NullPointerException("id is null");
if (c == null) throw new NullPointerException("Class is null");
if (id == null) {
throw new NullPointerException("id is null");
}
if (c == null) {
throw new NullPointerException("Class is null");
}

if (getBroadcaster(id) != null)
if (store.containsKey(id)) {
throw new IllegalStateException("Broadcaster already existing " + id + ". Use BroadcasterFactory.lookup instead");
}

Broadcaster b = null;
synchronized (id) {

// If two thread comes here at the same time, the second ID will erase the
if (store.get(id) != null){
return store.get(id);
}
return lookup(c, id, true);
}

try {
b = c.getConstructor(String.class, AtmosphereServlet.AtmosphereConfig.class).newInstance(id.toString(), config);
} catch (Throwable t) {
throw new BroadcasterCreationException(t);
}
private Broadcaster createBroadcaster(Class<? extends Broadcaster> c, Object id) throws BroadcasterCreationException {
try {
Broadcaster b = c.getConstructor(String.class, AtmosphereServlet.AtmosphereConfig.class).newInstance(id.toString(), config);
InjectorProvider.getInjector().inject(b);
b.setBroadcasterConfig(new BroadcasterConfig(AtmosphereServlet.broadcasterFilters, config));
b.setBroadcasterLifeCyclePolicy(policy);

if (DefaultBroadcaster.class.isAssignableFrom(clazz)) {
DefaultBroadcaster.class.cast(b).start();
}
store.put(id, b);
logger.debug("Added Broadcaster {} . Factory size: {}", id, store.size());
return b;
} catch (Throwable t) {
throw new BroadcasterCreationException(t);
}
return b;
}

/**
* Return a {@link Broadcaster} based on its name.
*
* @param name The unique ID
* @return a {@link Broadcaster}, or null
*/
private Broadcaster getBroadcaster(Object name) {
return store.get(name);
}

/**
Expand All @@ -182,9 +167,8 @@ public boolean add(Broadcaster b, Object id) {
* {@inheritDoc}
*/
public boolean remove(Broadcaster b, Object id) {
boolean removed = (store.get(b.getID()) == b);
boolean removed = store.remove(id, b);
if (removed) {
store.remove(id, b);
logger.debug("Removing Broadcaster {} which internal reference is {} ", id, b.getID());
}
return removed;
Expand Down Expand Up @@ -216,10 +200,10 @@ public final Broadcaster lookup(Object id, boolean createIfNull) {
*/
@Override
public Broadcaster lookup(Class<? extends Broadcaster> c, Object id, boolean createIfNull) {
Broadcaster b = getBroadcaster(id);
Broadcaster b = store.get(id);
if (b != null && !c.isAssignableFrom(b.getClass())) {
String msg = "Invalid lookup class " + c.getName() + ". Cached class is: " + b.getClass().getName();
logger.debug("{}", msg);
logger.debug(msg);
throw new IllegalStateException(msg);
}

Expand All @@ -228,7 +212,10 @@ public Broadcaster lookup(Class<? extends Broadcaster> c, Object id, boolean cre
logger.debug("Removing destroyed Broadcaster {}", b.getID());
store.remove(b.getID(), b);
}
b = get(c, id);
if (store.putIfAbsent(id, createBroadcaster(c, id)) == null) {
logger.debug("Added Broadcaster {} . Factory size: {}", id, store.size());
}
b = store.get(id);
}

return b;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package org.atmosphere.cpr;

import org.atmosphere.cpr.AtmosphereServlet.AtmosphereConfig;

import org.atmosphere.util.SimpleBroadcaster;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
import static org.mockito.Mockito.*;

/**
*
* @author jburgess
*/
public class DefaultBroadcasterFactoryTest {

private AtmosphereConfig config;
private DefaultBroadcasterFactory factory;

@BeforeMethod
public void setUp() throws Exception {
config = mock(AtmosphereConfig.class);
factory = new DefaultBroadcasterFactory(DefaultBroadcaster.class, "NEVER", config);
}

@Test
public void testGet_0args() {
Broadcaster result = factory.get();
assert result != null;
assert result instanceof DefaultBroadcaster;
}

@Test
public void testGet_Object() {
String id = "id";
Broadcaster result = factory.get(id);
assert result != null;
assert result instanceof DefaultBroadcaster;
assert id.equals(result.getID());
}

@Test(expectedExceptions = IllegalStateException.class)
public void testGet_Object_Twice() {
String id = "id";
factory.get(id);
factory.get(id);
}

@Test
public void testAdd() {
String id = "id";
String id2 = "foo";
Broadcaster b = factory.get(id);
assert factory.add(b, id) == false;
assert factory.lookup(id) != null;
assert factory.add(b, id2) == true;
assert factory.lookup(id2) != null;
}

@Test
public void testRemove() {
String id = "id";
String id2 = "foo";
Broadcaster b = factory.get(id);
Broadcaster b2 = factory.get(id2);
assert factory.remove(b, id2) == false;
assert factory.remove(b2, id) == false;
assert factory.remove(b, id) == true;
assert factory.lookup(id) == null;
}

@Test
public void testLookup_Class_Object() {
String id = "id";
String id2 = "foo";
Broadcaster b = factory.get(id);
assert factory.lookup(DefaultBroadcaster.class, id) != null;
assert factory.lookup(DefaultBroadcaster.class, id2) == null;
}

@Test(expectedExceptions = IllegalStateException.class)
public void testLookup_Class_Object_BadClass() {
String id = "id";
factory.get(id);
factory.lookup(SimpleBroadcaster.class, id);
}

@Test
public void testLookup_Object() {
String id = "id";
String id2 = "foo";
factory.get(id);
assert factory.lookup(id) != null;
assert factory.lookup(id2) == null;
}

@Test
public void testLookup_Object_boolean() {
String id = "id";
assert factory.lookup(id, false) == null;
Broadcaster b = factory.lookup(id, true);
assert b != null;
assert id.equals(b.getID());
}

@Test
public void testLookup_Class_Object_boolean() {
String id = "id";
assert factory.lookup(DefaultBroadcaster.class, id, false) == null;
Broadcaster b = factory.lookup(DefaultBroadcaster.class, id, true);
assert b != null;
assert b instanceof DefaultBroadcaster;
assert id.equals(b.getID());
}
}

0 comments on commit 592ed26

Please sign in to comment.