From 592ed265c4e641b2a37cf2a4ff51bf678c064899 Mon Sep 17 00:00:00 2001 From: Jason Burgess Date: Mon, 5 Dec 2011 13:33:31 -0700 Subject: [PATCH] Fix race condition in DefaultBroadcasterFactory Fix race condition in DefaultBroadcasterFactory --- .../atmosphere/cpr/BroadcasterFactory.java | 2 +- .../cpr/DefaultBroadcasterFactory.java | 63 ++++------ .../cpr/DefaultBroadcasterFactoryTest.java | 118 ++++++++++++++++++ 3 files changed, 144 insertions(+), 39 deletions(-) create mode 100644 modules/cpr/src/test/java/org/atmosphere/cpr/DefaultBroadcasterFactoryTest.java diff --git a/modules/cpr/src/main/java/org/atmosphere/cpr/BroadcasterFactory.java b/modules/cpr/src/main/java/org/atmosphere/cpr/BroadcasterFactory.java index a46d0f4663b..8ef6782f56a 100644 --- a/modules/cpr/src/main/java/org/atmosphere/cpr/BroadcasterFactory.java +++ b/modules/cpr/src/main/java/org/atmosphere/cpr/BroadcasterFactory.java @@ -92,7 +92,7 @@ public abstract class BroadcasterFactory { * * @param b a {@link Broadcaster} * @return false if wasn't present, or {@link Broadcaster} - * @oaram id the {@link Broadcaster's ID} + * @param id the {@link Broadcaster's ID} */ abstract public boolean remove(Broadcaster b, Object id); diff --git a/modules/cpr/src/main/java/org/atmosphere/cpr/DefaultBroadcasterFactory.java b/modules/cpr/src/main/java/org/atmosphere/cpr/DefaultBroadcasterFactory.java index 250272f62e5..5a93c243a36 100755 --- a/modules/cpr/src/main/java/org/atmosphere/cpr/DefaultBroadcasterFactory.java +++ b/modules/cpr/src/main/java/org/atmosphere/cpr/DefaultBroadcasterFactory.java @@ -40,11 +40,9 @@ import org.atmosphere.di.InjectorProvider; -import org.atmosphere.util.AbstractBroadcasterProxy; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.lang.reflect.InvocationTargetException; import java.util.Collection; import java.util.Collections; import java.util.Enumeration; @@ -70,9 +68,9 @@ public class DefaultBroadcasterFactory extends BroadcasterFactory { private static final Logger logger = LoggerFactory.getLogger(DefaultBroadcasterFactory.class); private final ConcurrentHashMap store = new ConcurrentHashMap(); - + private final Class clazz; - + private BroadcasterLifeCyclePolicy policy = new BroadcasterLifeCyclePolicy.Builder().policy(NEVER).build(); @@ -129,46 +127,33 @@ public final Broadcaster get(Object id) { */ public final Broadcaster get(Class c, Object id) { - if (id == null) throw new NullPointerException("id is null"); - if (c == null) throw new NullPointerException("Class is null"); + if (id == null) { + throw new NullPointerException("id is null"); + } + if (c == null) { + throw new NullPointerException("Class is null"); + } - if (getBroadcaster(id) != null) + if (store.containsKey(id)) { throw new IllegalStateException("Broadcaster already existing " + id + ". Use BroadcasterFactory.lookup instead"); + } - Broadcaster b = null; - synchronized (id) { - - // If two thread comes here at the same time, the second ID will erase the - if (store.get(id) != null){ - return store.get(id); - } + return lookup(c, id, true); + } - try { - b = c.getConstructor(String.class, AtmosphereServlet.AtmosphereConfig.class).newInstance(id.toString(), config); - } catch (Throwable t) { - throw new BroadcasterCreationException(t); - } + private Broadcaster createBroadcaster(Class c, Object id) throws BroadcasterCreationException { + try { + Broadcaster b = c.getConstructor(String.class, AtmosphereServlet.AtmosphereConfig.class).newInstance(id.toString(), config); InjectorProvider.getInjector().inject(b); b.setBroadcasterConfig(new BroadcasterConfig(AtmosphereServlet.broadcasterFilters, config)); b.setBroadcasterLifeCyclePolicy(policy); - if (DefaultBroadcaster.class.isAssignableFrom(clazz)) { DefaultBroadcaster.class.cast(b).start(); } - store.put(id, b); - logger.debug("Added Broadcaster {} . Factory size: {}", id, store.size()); + return b; + } catch (Throwable t) { + throw new BroadcasterCreationException(t); } - return b; - } - - /** - * Return a {@link Broadcaster} based on its name. - * - * @param name The unique ID - * @return a {@link Broadcaster}, or null - */ - private Broadcaster getBroadcaster(Object name) { - return store.get(name); } /** @@ -182,9 +167,8 @@ public boolean add(Broadcaster b, Object id) { * {@inheritDoc} */ public boolean remove(Broadcaster b, Object id) { - boolean removed = (store.get(b.getID()) == b); + boolean removed = store.remove(id, b); if (removed) { - store.remove(id, b); logger.debug("Removing Broadcaster {} which internal reference is {} ", id, b.getID()); } return removed; @@ -216,10 +200,10 @@ public final Broadcaster lookup(Object id, boolean createIfNull) { */ @Override public Broadcaster lookup(Class c, Object id, boolean createIfNull) { - Broadcaster b = getBroadcaster(id); + Broadcaster b = store.get(id); if (b != null && !c.isAssignableFrom(b.getClass())) { String msg = "Invalid lookup class " + c.getName() + ". Cached class is: " + b.getClass().getName(); - logger.debug("{}", msg); + logger.debug(msg); throw new IllegalStateException(msg); } @@ -228,7 +212,10 @@ public Broadcaster lookup(Class c, Object id, boolean cre logger.debug("Removing destroyed Broadcaster {}", b.getID()); store.remove(b.getID(), b); } - b = get(c, id); + if (store.putIfAbsent(id, createBroadcaster(c, id)) == null) { + logger.debug("Added Broadcaster {} . Factory size: {}", id, store.size()); + } + b = store.get(id); } return b; diff --git a/modules/cpr/src/test/java/org/atmosphere/cpr/DefaultBroadcasterFactoryTest.java b/modules/cpr/src/test/java/org/atmosphere/cpr/DefaultBroadcasterFactoryTest.java new file mode 100644 index 00000000000..88fba8682a5 --- /dev/null +++ b/modules/cpr/src/test/java/org/atmosphere/cpr/DefaultBroadcasterFactoryTest.java @@ -0,0 +1,118 @@ +/* + * To change this template, choose Tools | Templates + * and open the template in the editor. + */ +package org.atmosphere.cpr; + +import org.atmosphere.cpr.AtmosphereServlet.AtmosphereConfig; + +import org.atmosphere.util.SimpleBroadcaster; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; +import static org.mockito.Mockito.*; + +/** + * + * @author jburgess + */ +public class DefaultBroadcasterFactoryTest { + + private AtmosphereConfig config; + private DefaultBroadcasterFactory factory; + + @BeforeMethod + public void setUp() throws Exception { + config = mock(AtmosphereConfig.class); + factory = new DefaultBroadcasterFactory(DefaultBroadcaster.class, "NEVER", config); + } + + @Test + public void testGet_0args() { + Broadcaster result = factory.get(); + assert result != null; + assert result instanceof DefaultBroadcaster; + } + + @Test + public void testGet_Object() { + String id = "id"; + Broadcaster result = factory.get(id); + assert result != null; + assert result instanceof DefaultBroadcaster; + assert id.equals(result.getID()); + } + + @Test(expectedExceptions = IllegalStateException.class) + public void testGet_Object_Twice() { + String id = "id"; + factory.get(id); + factory.get(id); + } + + @Test + public void testAdd() { + String id = "id"; + String id2 = "foo"; + Broadcaster b = factory.get(id); + assert factory.add(b, id) == false; + assert factory.lookup(id) != null; + assert factory.add(b, id2) == true; + assert factory.lookup(id2) != null; + } + + @Test + public void testRemove() { + String id = "id"; + String id2 = "foo"; + Broadcaster b = factory.get(id); + Broadcaster b2 = factory.get(id2); + assert factory.remove(b, id2) == false; + assert factory.remove(b2, id) == false; + assert factory.remove(b, id) == true; + assert factory.lookup(id) == null; + } + + @Test + public void testLookup_Class_Object() { + String id = "id"; + String id2 = "foo"; + Broadcaster b = factory.get(id); + assert factory.lookup(DefaultBroadcaster.class, id) != null; + assert factory.lookup(DefaultBroadcaster.class, id2) == null; + } + + @Test(expectedExceptions = IllegalStateException.class) + public void testLookup_Class_Object_BadClass() { + String id = "id"; + factory.get(id); + factory.lookup(SimpleBroadcaster.class, id); + } + + @Test + public void testLookup_Object() { + String id = "id"; + String id2 = "foo"; + factory.get(id); + assert factory.lookup(id) != null; + assert factory.lookup(id2) == null; + } + + @Test + public void testLookup_Object_boolean() { + String id = "id"; + assert factory.lookup(id, false) == null; + Broadcaster b = factory.lookup(id, true); + assert b != null; + assert id.equals(b.getID()); + } + + @Test + public void testLookup_Class_Object_boolean() { + String id = "id"; + assert factory.lookup(DefaultBroadcaster.class, id, false) == null; + Broadcaster b = factory.lookup(DefaultBroadcaster.class, id, true); + assert b != null; + assert b instanceof DefaultBroadcaster; + assert id.equals(b.getID()); + } +}