Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sonarfixes #1

Merged
merged 7 commits into from
May 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.PosixFilePermission;
import java.nio.file.attribute.PosixFilePermissions;
import java.security.Security;
import java.util.Arrays;
import java.util.HashSet;
Expand Down Expand Up @@ -67,20 +68,15 @@ protected static String getEncryptedModelPath(String modelName, String folderPat
return files[0].toString();
}

protected static void createDecryptionFolder(String folderPath) throws IOException {
Path targetFolderPath = Paths.get(folderPath);

if (Files.exists(targetFolderPath)) {
throw new IOException("Target path " + targetFolderPath.toString() + " already exists");
}

logger.debug("Creating decryption folder at path: {}", folderPath);

Files.createDirectories(targetFolderPath);

protected static String createDecryptionFolder(String prefix) throws IOException {
Set<PosixFilePermission> permissions = new HashSet<>(Arrays.asList(PosixFilePermission.OWNER_READ,
PosixFilePermission.OWNER_WRITE, PosixFilePermission.OWNER_EXECUTE));
Files.setPosixFilePermissions(targetFolderPath, permissions);
Path tempFolderPath = Files.createTempDirectory(prefix, PosixFilePermissions.asFileAttribute(permissions));
// TODO shutdown hook tempFolderPath.toFile().deleteOnExit();

logger.debug("Created temporary directory at path {}", tempFolderPath);

return tempFolderPath.toString();
}

protected static void decryptModel(String password, String inputFilePath, String outputFilePath)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class TritonServerLocalManager {
private static final int MONITOR_PERIOD = 30;
private static final String[] TRITONSERVER = new String[] { "tritonserver" };

private final String decryptionFolderPath;
private final CommandExecutorService commandExecutorService;
private final TritonServerServiceOptions options;
private Command serverCommand;
Expand All @@ -44,9 +45,10 @@ public class TritonServerLocalManager {
private ScheduledFuture<?> scheduledFuture;

protected TritonServerLocalManager(TritonServerServiceOptions options,
CommandExecutorService commandExecutorService) {
CommandExecutorService commandExecutorService, String decryptionFolderPath) {
this.options = options;
this.commandExecutorService = commandExecutorService;
this.decryptionFolderPath = decryptionFolderPath;
this.scheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
}

Expand Down Expand Up @@ -137,8 +139,8 @@ protected static void sleepFor(long timeout) {
private Command createServerCommand() {
List<String> commandString = new ArrayList<>();
commandString.add("tritonserver");
if (!this.options.getModelRepositoryPassword().isEmpty()) {
commandString.add("--model-repository=" + TritonServerServiceOptions.DECRYPTED_MODELS_REPO_PATH);
if (this.options.modelsAreEncrypted()) {
commandString.add("--model-repository=" + this.decryptionFolderPath);
} else {
commandString.add("--model-repository=" + this.options.getModelRepositoryPath());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -75,6 +74,7 @@
public class TritonServerServiceImpl implements InferenceEngineService, ConfigurableComponent {

private static final Logger logger = LoggerFactory.getLogger(TritonServerServiceImpl.class);
private static final String TEMP_DIRECTORY_PREFIX = "decrypted_models";

private CommandExecutorService commandExecutorService;
private CryptoService cryptoService;
Expand All @@ -83,6 +83,7 @@ public class TritonServerServiceImpl implements InferenceEngineService, Configur

private ManagedChannel grpcChannel;
private GRPCInferenceServiceBlockingStub grpcStub;
private String decryptionFolderPath = "";

public void setCommandExecutorService(CommandExecutorService executorService) {
this.commandExecutorService = executorService;
Expand All @@ -98,19 +99,7 @@ protected void activate(Map<String, Object> properties) {
if (isConfigurationValid()) {
setGrpcResources();
if (this.options.isLocalEnabled()) {
if (!this.options.getModelRepositoryPassword().isEmpty()
&& !Files.isDirectory(Paths.get(TritonServerServiceOptions.DECRYPTED_MODELS_REPO_PATH))) {
logger.info("Creating decryption model directory at {}",
TritonServerServiceOptions.DECRYPTED_MODELS_REPO_PATH);
try {
TritonServerEncryptionUtils
.createDecryptionFolder(TritonServerServiceOptions.DECRYPTED_MODELS_REPO_PATH);
} catch (IOException e) {
logger.warn("Failed to create decryption model directory", e);
}
}
this.tritonServerLocalManager = new TritonServerLocalManager(this.options, this.commandExecutorService);
this.tritonServerLocalManager.start();
startLocalInstance();
}
loadModels();
} else {
Expand All @@ -127,19 +116,7 @@ public void updated(Map<String, Object> properties) {
if (isConfigurationValid()) {
setGrpcResources();
if (this.options.isLocalEnabled()) {
if (!this.options.getModelRepositoryPassword().isEmpty()
&& !Files.isDirectory(Paths.get(TritonServerServiceOptions.DECRYPTED_MODELS_REPO_PATH))) {
logger.info("Creating decryption model directory at {}",
TritonServerServiceOptions.DECRYPTED_MODELS_REPO_PATH);
try {
TritonServerEncryptionUtils
.createDecryptionFolder(TritonServerServiceOptions.DECRYPTED_MODELS_REPO_PATH);
} catch (IOException e) {
logger.warn("Failed to create decryption model directory", e);
}
}
this.tritonServerLocalManager = new TritonServerLocalManager(this.options, this.commandExecutorService);
this.tritonServerLocalManager.start();
startLocalInstance();
} else {
this.tritonServerLocalManager = null;
}
Expand All @@ -149,6 +126,21 @@ public void updated(Map<String, Object> properties) {
}
}

private void startLocalInstance() {
if (this.options.modelsAreEncrypted()) {
try {
decryptionFolderPath = TritonServerEncryptionUtils.createDecryptionFolder(TEMP_DIRECTORY_PREFIX);
} catch (IOException e) {
logger.warn("Failed to create decryption model directory", e);
}

logger.info("Created decryption model directory at {}", decryptionFolderPath);
}
this.tritonServerLocalManager = new TritonServerLocalManager(this.options, this.commandExecutorService,
this.decryptionFolderPath);
this.tritonServerLocalManager.start();
}

protected void deactivate() {
logger.info("Deactivate TritonServerService...");
if (nonNull(this.tritonServerLocalManager)) {
Expand Down Expand Up @@ -203,20 +195,18 @@ protected void loadModels() {

@Override
public void loadModel(String modelName, Optional<String> modelPath) throws KuraException {
String password = this.options.getModelRepositoryPassword();

if (!password.isEmpty()) {
if (this.options.modelsAreEncrypted()) {
String password = this.options.getModelRepositoryPassword();
String plainPassword = String.valueOf(cryptoService.decryptAes(password.toCharArray()));
String encryptedModelPath = TritonServerEncryptionUtils.getEncryptedModelPath(modelName,
this.options.getModelRepositoryPath());
String decryptedModelPath = TritonServerServiceOptions.DECRYPTED_MODELS_REPO_PATH + modelName + ".zip";
String decryptedModelPath = Paths.get(decryptionFolderPath, modelName + ".zip").toString();

logger.info("Model decryption password detected. Decrypting model {} at {} into {}", modelName,
encryptedModelPath, decryptedModelPath);
try {
TritonServerEncryptionUtils.decryptModel(plainPassword, encryptedModelPath, decryptedModelPath);
TritonServerEncryptionUtils.unzipModel(decryptedModelPath,
TritonServerServiceOptions.DECRYPTED_MODELS_REPO_PATH);
TritonServerEncryptionUtils.unzipModel(decryptedModelPath, decryptionFolderPath);
} catch (KuraIOException | IOException e) {
throw new KuraIOException(e, "Cannot decrypt the model " + modelName);
}
Expand All @@ -227,11 +217,13 @@ public void loadModel(String modelName, Optional<String> modelPath) throws KuraE
try {
this.grpcStub.repositoryModelLoad(builder.build());
} catch (StatusRuntimeException e) {
TritonServerEncryptionUtils.cleanRepository(TritonServerServiceOptions.DECRYPTED_MODELS_REPO_PATH);
if (this.options.modelsAreEncrypted()) {
TritonServerEncryptionUtils.cleanRepository(decryptionFolderPath);
}
throw new KuraIOException(e, "Cannot load the model " + modelName);
}

if (!password.isEmpty()) {
if (this.options.modelsAreEncrypted()) {
int counter = 0;
while (!isModelLoaded(modelName)) {
if (counter++ >= 6) {
Expand All @@ -240,7 +232,7 @@ public void loadModel(String modelName, Optional<String> modelPath) throws KuraE
}
TritonServerLocalManager.sleepFor(250);
}
TritonServerEncryptionUtils.cleanRepository(TritonServerServiceOptions.DECRYPTED_MODELS_REPO_PATH);
TritonServerEncryptionUtils.cleanRepository(decryptionFolderPath);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@

public class TritonServerServiceOptions {

public static final String DECRYPTED_MODELS_REPO_PATH = "/tmp/decrypted_models/";

private static final String PROPERTY_ADDRESS = "server.address";
private static final String PROPERTY_PORTS = "server.ports";
private static final String PROPERTY_LOCAL_MODEL_REPOSITORY_PATH = "local.model.repository.path";
Expand Down Expand Up @@ -117,6 +115,10 @@ public String getBackendsPath() {
return getStringProperty(PROPERTY_LOCAL_BACKENDS_PATH);
}

public boolean modelsAreEncrypted() {
return !getModelRepositoryPassword().isEmpty();
}

public List<String> getBackendsConfigs() {
List<String> backendsConfigs = new ArrayList<>();
final Object propertyBackendsConfig = this.properties.get(PROPERTY_LOCAL_BACKENDS_CONFIG);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public class TritonServerEncryptionUtilsTest {

private static final String WORKDIR = System.getProperty("java.io.tmpdir") + "/decr_folder";
private boolean exceptionOccurred = false;
private String tempDirectoryPrefix;
private String targetFolder;
private String modelName;
private String expectedEncryptedModelPath;
Expand Down Expand Up @@ -114,38 +115,15 @@ public void getEncryptedModelPathShouldThrowIfMultipleMatchesFound() {

@Test
public void createDecryptionFolderShouldWork() {
givenTargetFolder(WORKDIR + "/target_folder");
givenNoFileExistsAtPath(targetFolder);

whenCreateDecryptionFolderIsCalledWith(targetFolder);

thenAFolderExistsAtPath(targetFolder);
thenTargetFolderHasPermissions(targetFolder, "rwx------");
thenNoExceptionOccurred();
}

@Test
public void createDecryptionFolderShouldWorkWithNestedPath() {
givenTargetFolder(WORKDIR + "/new/nested/folder");
givenNoFileExistsAtPath(targetFolder);
givenTempDirectoryPrefix("prefix");

whenCreateDecryptionFolderIsCalledWith(targetFolder);
whenCreateDecryptionFolderIsCalledWith(tempDirectoryPrefix);

thenAFolderExistsAtPath(targetFolder);
thenTargetFolderHasPermissions(targetFolder, "rwx------");
thenNoExceptionOccurred();
}

@Test
public void createDecryptionFolderShouldThrowOnNameClashes() {
givenTargetFolder(WORKDIR + "/another_folder");
givenAFileAreadyExistsAtPath(targetFolder);

whenCreateDecryptionFolderIsCalledWith(targetFolder);

thenAnExceptionOccurred();
}

@Test
public void decryptModelShouldWork() {
givenEncryptedFileAtPath("target/test-classes/plain_file.gpg");
Expand Down Expand Up @@ -328,6 +306,10 @@ private void givenExpectedModelPath(String modelPath) {
this.expectedEncryptedModelPath = modelPath;
}

private void givenTempDirectoryPrefix(String prefix) {
this.tempDirectoryPrefix = prefix;
}

private void givenTargetFolder(String folderPath) {
this.targetFolder = folderPath;
}
Expand Down Expand Up @@ -411,7 +393,7 @@ private void whenGetEncryptedModelPathIsCalledWith(String modelName, String fold

private void whenCreateDecryptionFolderIsCalledWith(String folderPath) {
try {
TritonServerEncryptionUtils.createDecryptionFolder(folderPath);
this.targetFolder = TritonServerEncryptionUtils.createDecryptionFolder(folderPath);
} catch (IOException e) {
e.printStackTrace();
this.exceptionOccurred = true;
Expand Down