Skip to content

Commit

Permalink
#6804: fix static import magic command
Browse files Browse the repository at this point in the history
  • Loading branch information
jaroslawmalekcodete committed Feb 9, 2018
1 parent fa28a12 commit 8c29f59
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 5 deletions.
54 changes: 54 additions & 0 deletions doc/groovy/ClasspathMagicCommands.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,60 @@
"source": [
"%import static com.example.Demo.staticTest"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"staticTest()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%import static com.example.Demo.STATIC_TEST_123"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"STATIC_TEST_123"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%import static com.example.Demo.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"STATIC_TEST_123"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"staticTest()"
]
}
],
"metadata": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ public String asString() {
return anImport;
}

public boolean isStatic() {
return anImport.startsWith("static");
}

@Override
public boolean equals(Object o) {
return reflectionEquals(this, o);
Expand All @@ -48,4 +52,8 @@ public int hashCode() {
public String toString() {
return reflectionToString(this);
}

public String path() {
return isStatic() ? anImport.replace("static ", "") : anImport;
}
}
68 changes: 64 additions & 4 deletions kernel/base/src/main/java/com/twosigma/beakerx/kernel/Imports.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
*/
package com.twosigma.beakerx.kernel;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.reflect.ClassPath;

import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

Expand All @@ -30,11 +32,11 @@

public class Imports {

private List<ImportPath> imports = new ArrayList<>();
private List<ImportPath> imports;
private List<String> importsAsStrings = null;

public Imports(List<ImportPath> importPaths) {
this.imports = Preconditions.checkNotNull(importPaths);
this.imports = checkNotNull(importPaths);
}

public List<ImportPath> getImportPaths() {
Expand Down Expand Up @@ -92,14 +94,72 @@ public String toString() {
}

private boolean isImportPathValid(ImportPath anImport, ClassLoader classLoader) {
String importToCheck = anImport.asString();
if (anImport.isStatic()) {
return isValidStaticImport(anImport, classLoader);
} else {
return isValidImport(anImport, classLoader);
}
}

private boolean isValidStaticImport(ImportPath anImport, ClassLoader classLoader) {
String importToCheck = anImport.path();
if (!importToCheck.contains(".")) {
return false;
}
if (importToCheck.endsWith(".*")) {
return isValidStaticImportWithWildcard(classLoader, importToCheck);
} else {
return isValidStatic(classLoader, importToCheck);
}
}

private boolean isValidImport(ImportPath anImport, ClassLoader classLoader) {
String importToCheck = anImport.path();
if (importToCheck.endsWith(".*")) {
return isValidImportWithWildcard(importToCheck, classLoader);
} else {
return isValidClassImport(importToCheck, classLoader);
}
}

private boolean isValidStaticImportWithWildcard(ClassLoader classLoader, String importToCheck) {
String classImport = importToCheck.substring(0, importToCheck.lastIndexOf("."));
return isValidClassImport(classImport, classLoader);
}

private boolean isValidStatic(ClassLoader classLoader, String importToCheck) {
String packageImport = importToCheck.substring(0, importToCheck.lastIndexOf("."));
boolean validClassImport = isValidClassImport(packageImport, classLoader);
if (validClassImport) {
String methodOrName = importToCheck.substring(importToCheck.lastIndexOf("."), importToCheck.length()).replaceFirst(".", "");
if (methodOrName.isEmpty()) {
return false;
}
try {
Class<?> aClass = classLoader.loadClass(packageImport);
List<Method> methods = getMethods(methodOrName, aClass);
if (!methods.isEmpty()) {
return true;
}
List<Field> fields = getFields(methodOrName, aClass);
return !fields.isEmpty();
} catch (ClassNotFoundException e) {
return false;
}
}
return false;
}

private List<Field> getFields(String methodOrName, Class<?> aClass) {
Field[] publicFields = aClass.getFields();
return Arrays.stream(publicFields).filter(x -> x.getName().equals(methodOrName)).collect(Collectors.toList());
}

private List<Method> getMethods(String methodOrName, Class<?> aClass) {
Method[] publicMethods = aClass.getMethods();
return Arrays.stream(publicMethods).filter(x -> x.getName().equals(methodOrName)).collect(Collectors.toList());
}

private boolean isValidClassImport(String importToCheck, ClassLoader classLoader) {
try {
classLoader.loadClass(importToCheck);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package com.twosigma.beakerx.kernel.magic.command.functionality;

import com.twosigma.beakerx.kernel.AddImportStatus;
import com.twosigma.beakerx.kernel.ImportPath;
import com.twosigma.beakerx.kernel.KernelFunctionality;
import com.twosigma.beakerx.kernel.magic.command.MagicCommandExecutionParam;
Expand All @@ -23,6 +24,7 @@
import com.twosigma.beakerx.kernel.magic.command.outcome.MagicCommandOutput;

import static com.twosigma.beakerx.kernel.magic.command.functionality.AddImportMagicCommand.IMPORT;
import static com.twosigma.beakerx.kernel.magic.command.outcome.MagicCommandOutcomeItem.Status.ERROR;

public class AddStaticImportMagicCommand implements MagicCommandFunctionality {

Expand Down Expand Up @@ -52,7 +54,10 @@ public MagicCommandOutcomeItem execute(MagicCommandExecutionParam param) {
if (parts.length != 3) {
return new MagicCommandOutput(MagicCommandOutput.Status.ERROR, WRONG_FORMAT_MSG);
}
this.kernel.addImport(new ImportPath(parts[1] + " " + parts[2]));
AddImportStatus status = this.kernel.addImport(new ImportPath(parts[1] + " " + parts[2]));
if (AddImportStatus.ERROR.equals(status)) {
return new MagicCommandOutput(ERROR, "Could not import static " + parts[2]);
}
return new MagicCommandOutput(MagicCommandOutput.Status.OK);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

import com.twosigma.beakerx.KernelExecutionTest;
import com.twosigma.beakerx.kernel.CloseKernelAction;
import com.twosigma.beakerx.kernel.Code;
import com.twosigma.beakerx.kernel.Kernel;
import com.twosigma.beakerx.kernel.KernelSocketsFactory;
import com.twosigma.beakerx.kernel.Utils;
import com.twosigma.beakerx.kernel.comm.Comm;
import com.twosigma.beakerx.kernel.magic.command.CodeFactory;
import com.twosigma.beakerx.kernel.msg.JupyterMessages;
import com.twosigma.beakerx.message.Header;
import com.twosigma.beakerx.message.Message;
Expand All @@ -37,9 +39,11 @@
import static com.twosigma.beakerx.MessageFactoryTest.getExecuteRequestMessage;
import static com.twosigma.beakerx.evaluator.EvaluatorResultTestWatcher.waitForIdleMessage;
import static com.twosigma.beakerx.evaluator.EvaluatorResultTestWatcher.waitForResult;
import static com.twosigma.beakerx.evaluator.EvaluatorResultTestWatcher.waitForStderr;
import static com.twosigma.beakerx.evaluator.EvaluatorResultTestWatcher.waitForUpdateMessage;
import static com.twosigma.beakerx.evaluator.EvaluatorTest.getCacheFolderFactory;
import static com.twosigma.beakerx.groovy.TestGroovyEvaluator.groovyEvaluator;
import static com.twosigma.beakerx.kernel.magic.command.functionality.AddStaticImportMagicCommand.ADD_STATIC_IMPORT;
import static org.assertj.core.api.Assertions.assertThat;

public class GroovyKernelTest extends KernelExecutionTest {
Expand Down Expand Up @@ -146,4 +150,105 @@ private Message outputWidgetUpdateMessage(String outputCommId) {
return message;
}

@Test
public void shouldImportStaticWildcardDemoClassByMagicCommand() throws Exception {
//given
addDemoJar();
String path = pathToDemoClassFromAddedDemoJar() + ".*";
//when
Code code = CodeFactory.create(ADD_STATIC_IMPORT + " " + path, new Message(), getKernel());
code.execute(kernel, 1);
//then
verifyStaticImportedDemoClassByMagicCommand(pathToDemoClassFromAddedDemoJar() + ".staticTest()");
verifyStaticImportedDemoClassByMagicCommand(pathToDemoClassFromAddedDemoJar() + ".STATIC_TEST_123");
}

@Test
public void shouldImportStaticMethodDemoClassByMagicCommand() throws Exception {
//given
addDemoJar();
String path = pathToDemoClassFromAddedDemoJar() + ".staticTest";
//when
Code code = CodeFactory.create(ADD_STATIC_IMPORT + " " + path, new Message(), getKernel());
code.execute(kernel, 1);
//then
verifyStaticImportedDemoClassByMagicCommand(pathToDemoClassFromAddedDemoJar() + ".staticTest()");
}

@Test
public void shouldImportStaticFieldDemoClassByMagicCommand() throws Exception {
//given
addDemoJar();
String path = pathToDemoClassFromAddedDemoJar() + ".STATIC_TEST_123";
//when
Code code = CodeFactory.create(ADD_STATIC_IMPORT + " " + path, new Message(), getKernel());
code.execute(kernel, 1);
//then
verifyStaticImportedDemoClassByMagicCommand(pathToDemoClassFromAddedDemoJar() + ".STATIC_TEST_123");
}

protected void verifyStaticImportedDemoClassByMagicCommand(String path) throws InterruptedException {
Message message = getExecuteRequestMessage(path);
getKernelSocketsService().handleMsg(message);
Optional<Message> idleMessage = waitForIdleMessage(getKernelSocketsService().getKernelSockets());
assertThat(idleMessage).isPresent();
Optional<Message> result = waitForResult(getKernelSocketsService().getKernelSockets());
Map actual = ((Map) result.get().getContent().get(Comm.DATA));
String value = (String) actual.get("text/plain");
assertThat(value).isEqualTo("Demo_static_test_123");
}

@Test
public void shouldNotImportStaticUnknownClassByMagicCommand() throws Exception {
//given
String allCode = ADD_STATIC_IMPORT + " " + pathToDemoClassFromAddedDemoJar() + "UnknownClass";
//when
Code code = CodeFactory.create(allCode, new Message(), getKernel());
code.execute(kernel, 1);
//then
verifyNotImportedStaticMagicCommand();
}

@Test
public void shouldNotImportStaticUnknownFieldDemoClassByMagicCommand() throws Exception {
//given
addDemoJar();
String path = pathToDemoClassFromAddedDemoJar() + ".STATIC_TEST_123_unknown";
//when
Code code = CodeFactory.create(ADD_STATIC_IMPORT + " " + path, new Message(), getKernel());
code.execute(kernel, 1);
//then
verifyNotImportedStaticMagicCommand();
}

@Test
public void shouldNotImportStaticUnknownMethodDemoClassByMagicCommand() throws Exception {
//given
addDemoJar();
String path = pathToDemoClassFromAddedDemoJar() + ".staticTest_unknown";
//when
Code code = CodeFactory.create(ADD_STATIC_IMPORT + " " + path, new Message(), getKernel());
code.execute(kernel, 1);
//then
verifyNotImportedStaticMagicCommand();
}

@Test
public void shouldNotImportStaticNotStaticPathByMagicCommand() throws Exception {
//given
addDemoJar();
String path = "garbage";
//when
Code code = CodeFactory.create(ADD_STATIC_IMPORT + " " + path, new Message(), getKernel());
code.execute(kernel, 1);
//then
verifyNotImportedStaticMagicCommand();
}

private void verifyNotImportedStaticMagicCommand() throws InterruptedException {
List<Message> std = waitForStderr(getKernelSocketsService().getKernelSockets());
String text = (String) std.get(0).getContent().get("text");
assertThat(text).contains("Could not import static");
}

}

0 comments on commit 8c29f59

Please sign in to comment.