diff --git a/data-prepper-plugins/saas-source-plugins/jira-source/src/main/java/org/opensearch/dataprepper/plugins/source/jira/JiraClient.java b/data-prepper-plugins/saas-source-plugins/jira-source/src/main/java/org/opensearch/dataprepper/plugins/source/jira/JiraClient.java index 3c809fa798..0785741248 100644 --- a/data-prepper-plugins/saas-source-plugins/jira-source/src/main/java/org/opensearch/dataprepper/plugins/source/jira/JiraClient.java +++ b/data-prepper-plugins/saas-source-plugins/jira-source/src/main/java/org/opensearch/dataprepper/plugins/source/jira/JiraClient.java @@ -14,6 +14,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.event.EventType; @@ -87,7 +88,7 @@ void injectObjectMapper(ObjectMapper objectMapper) { @Override public void executePartition(SaasWorkerProgressState state, Buffer> buffer, - CrawlerSourceConfig configuration) { + AcknowledgementSet acknowledgementSet) { log.trace("Executing the partition: {} with {} ticket(s)", state.getKeyAttributes(), state.getItemIds().size()); List itemIds = state.getItemIds(); @@ -130,7 +131,13 @@ public void executePartition(SaasWorkerProgressState state, .collect(Collectors.toList()); try { - buffer.writeAll(recordsToWrite, (int) Duration.ofSeconds(bufferWriteTimeoutInSeconds).toMillis()); + if (configuration.isAcknowledgments()) { + recordsToWrite.forEach(eventRecord -> acknowledgementSet.add(eventRecord.getData())); + buffer.writeAll(recordsToWrite, (int) Duration.ofSeconds(bufferWriteTimeoutInSeconds).toMillis()); + acknowledgementSet.complete(); + } else { + buffer.writeAll(recordsToWrite, (int) Duration.ofSeconds(bufferWriteTimeoutInSeconds).toMillis()); + } } catch (Exception e) { throw new RuntimeException(e); } diff --git a/data-prepper-plugins/saas-source-plugins/jira-source/src/main/java/org/opensearch/dataprepper/plugins/source/jira/JiraSourceConfig.java b/data-prepper-plugins/saas-source-plugins/jira-source/src/main/java/org/opensearch/dataprepper/plugins/source/jira/JiraSourceConfig.java index df5cd70f0b..ed873fc49b 100644 --- a/data-prepper-plugins/saas-source-plugins/jira-source/src/main/java/org/opensearch/dataprepper/plugins/source/jira/JiraSourceConfig.java +++ b/data-prepper-plugins/saas-source-plugins/jira-source/src/main/java/org/opensearch/dataprepper/plugins/source/jira/JiraSourceConfig.java @@ -59,6 +59,12 @@ public class JiraSourceConfig implements CrawlerSourceConfig { @JsonProperty("backoff_time") private Duration backOff = DEFAULT_BACKOFF_MILLIS; + /** + * Boolean property indicating end to end acknowledgments state + */ + @JsonProperty("acknowledgments") + private boolean acknowledgments = false; + public String getAccountUrl() { return this.getHosts().get(0); } diff --git a/data-prepper-plugins/saas-source-plugins/jira-source/src/main/java/org/opensearch/dataprepper/plugins/source/jira/rest/auth/JiraOauthConfig.java b/data-prepper-plugins/saas-source-plugins/jira-source/src/main/java/org/opensearch/dataprepper/plugins/source/jira/rest/auth/JiraOauthConfig.java index 9228fa82df..ddcf1c8468 100644 --- a/data-prepper-plugins/saas-source-plugins/jira-source/src/main/java/org/opensearch/dataprepper/plugins/source/jira/rest/auth/JiraOauthConfig.java +++ b/data-prepper-plugins/saas-source-plugins/jira-source/src/main/java/org/opensearch/dataprepper/plugins/source/jira/rest/auth/JiraOauthConfig.java @@ -12,6 +12,7 @@ import lombok.Getter; import org.opensearch.dataprepper.plugins.source.jira.JiraSourceConfig; +import org.opensearch.dataprepper.plugins.source.jira.configuration.Oauth2Config; import org.opensearch.dataprepper.plugins.source.jira.exception.UnAuthorizedException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -126,18 +127,17 @@ public void renewCredentials() { String payload = String.format(payloadTemplate, "refresh_token", clientId, clientSecret, refreshToken); HttpEntity entity = new HttpEntity<>(payload, headers); + Oauth2Config oauth2Config = jiraSourceConfig.getAuthenticationConfig().getOauth2Config(); try { ResponseEntity responseEntity = restTemplate.postForEntity(TOKEN_LOCATION, entity, Map.class); Map oauthClientResponse = responseEntity.getBody(); this.accessToken = (String) oauthClientResponse.get(ACCESS_TOKEN); this.refreshToken = (String) oauthClientResponse.get(REFRESH_TOKEN); this.expiresInSeconds = (int) oauthClientResponse.get(EXPIRES_IN); - this.expireTime = Instant.ofEpochMilli(System.currentTimeMillis() + (expiresInSeconds * 1000L)); + this.expireTime = Instant.now().plusSeconds(expiresInSeconds); // updating config object's PluginConfigVariable so that it updates the underlying Secret store - jiraSourceConfig.getAuthenticationConfig().getOauth2Config().getAccessToken() - .setValue(this.accessToken); - jiraSourceConfig.getAuthenticationConfig().getOauth2Config().getRefreshToken() - .setValue(this.refreshToken); + oauth2Config.getAccessToken().setValue(this.accessToken); + oauth2Config.getRefreshToken().setValue(this.refreshToken); log.info("Access Token and Refresh Token pair is now refreshed. Corresponding Secret store key updated."); } catch (HttpClientErrorException ex) { this.expireTime = Instant.ofEpochMilli(0); @@ -147,9 +147,12 @@ public void renewCredentials() { statusCode, ex.getMessage()); if (statusCode == HttpStatus.FORBIDDEN || statusCode == HttpStatus.UNAUTHORIZED) { log.info("Trying to refresh the secrets"); - // Try refreshing the secrets and see if that helps - // Refreshing one of the secret refreshes the entire store so we are good to trigger refresh on just one - jiraSourceConfig.getAuthenticationConfig().getOauth2Config().getAccessToken().refresh(); + // Refreshing the secrets. It should help if someone already renewed the tokens. + // Refreshing one of the secret refreshes the entire store so triggering refresh on just one + oauth2Config.getAccessToken().refresh(); + this.accessToken = (String) oauth2Config.getAccessToken().getValue(); + this.refreshToken = (String) oauth2Config.getRefreshToken().getValue(); + this.expireTime = Instant.now().plusSeconds(10); } throw new RuntimeException("Failed to renew access token message:" + ex.getMessage(), ex); } diff --git a/data-prepper-plugins/saas-source-plugins/jira-source/src/test/java/org/opensearch/dataprepper/plugins/source/jira/JiraClientTest.java b/data-prepper-plugins/saas-source-plugins/jira-source/src/test/java/org/opensearch/dataprepper/plugins/source/jira/JiraClientTest.java index 78531afd61..58df280ae9 100644 --- a/data-prepper-plugins/saas-source-plugins/jira-source/src/test/java/org/opensearch/dataprepper/plugins/source/jira/JiraClientTest.java +++ b/data-prepper-plugins/saas-source-plugins/jira-source/src/test/java/org/opensearch/dataprepper/plugins/source/jira/JiraClientTest.java @@ -18,10 +18,10 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.plugins.source.source_crawler.base.CrawlerSourceConfig; import org.opensearch.dataprepper.plugins.source.source_crawler.base.PluginExecutorServiceProvider; import org.opensearch.dataprepper.plugins.source.source_crawler.coordination.state.SaasWorkerProgressState; @@ -46,26 +46,21 @@ @ExtendWith(MockitoExtension.class) public class JiraClientTest { + private final PluginExecutorServiceProvider executorServiceProvider = new PluginExecutorServiceProvider(); @Mock private Buffer> buffer; - @Mock private SaasWorkerProgressState saasWorkerProgressState; @Mock - private CrawlerSourceConfig crawlerSourceConfig; - + private AcknowledgementSet acknowledgementSet; @Mock private JiraSourceConfig jiraSourceConfig; - @Mock private JiraService jiraService; - @Mock private JiraIterator jiraIterator; - private PluginExecutorServiceProvider executorServiceProvider = new PluginExecutorServiceProvider(); - @Test void testConstructor() { JiraClient jiraClient = new JiraClient(jiraService, jiraIterator, executorServiceProvider, jiraSourceConfig); @@ -98,7 +93,7 @@ void testExecutePartition() throws Exception { ArgumentCaptor>> recordsCaptor = ArgumentCaptor.forClass((Class) Collection.class); - jiraClient.executePartition(saasWorkerProgressState, buffer, crawlerSourceConfig); + jiraClient.executePartition(saasWorkerProgressState, buffer, acknowledgementSet); verify(buffer).writeAll(recordsCaptor.capture(), anyInt()); Collection> capturedRecords = recordsCaptor.getValue(); @@ -121,14 +116,13 @@ void testExecutePartitionError() throws Exception { when(jiraService.getIssue(anyString())).thenReturn("{\"id\":\"ID1\",\"key\":\"TEST-1\"}"); - ArgumentCaptor>> recordsCaptor = ArgumentCaptor.forClass((Class) Collection.class); ObjectMapper mockObjectMapper = mock(ObjectMapper.class); when(mockObjectMapper.readValue(any(String.class), any(TypeReference.class))).thenThrow(new JsonProcessingException("test") { }); jiraClient.injectObjectMapper(mockObjectMapper); - assertThrows(RuntimeException.class, () -> jiraClient.executePartition(saasWorkerProgressState, buffer, crawlerSourceConfig)); + assertThrows(RuntimeException.class, () -> jiraClient.executePartition(saasWorkerProgressState, buffer, acknowledgementSet)); } @Test @@ -147,6 +141,6 @@ void bufferWriteRuntimeTest() throws Exception { ArgumentCaptor>> recordsCaptor = ArgumentCaptor.forClass((Class) Collection.class); doThrow(new RuntimeException()).when(buffer).writeAll(recordsCaptor.capture(), anyInt()); - assertThrows(RuntimeException.class, () -> jiraClient.executePartition(saasWorkerProgressState, buffer, crawlerSourceConfig)); + assertThrows(RuntimeException.class, () -> jiraClient.executePartition(saasWorkerProgressState, buffer, acknowledgementSet)); } } \ No newline at end of file diff --git a/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/Crawler.java b/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/Crawler.java index 327fff0d89..9634289344 100644 --- a/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/Crawler.java +++ b/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/Crawler.java @@ -2,6 +2,7 @@ import io.micrometer.core.instrument.Timer; import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; @@ -23,7 +24,7 @@ @Named public class Crawler { private static final Logger log = LoggerFactory.getLogger(Crawler.class); - private static final int maxItemsPerPage = 50; + private static final int maxItemsPerPage = 100; private final Timer crawlingTimer; private final PluginMetrics pluginMetrics = PluginMetrics.fromNames("sourceCrawler", "crawler"); @@ -61,8 +62,8 @@ public Instant crawl(Instant lastPollTime, return Instant.ofEpochMilli(startTime); } - public void executePartition(SaasWorkerProgressState state, Buffer> buffer, CrawlerSourceConfig sourceConfig) { - client.executePartition(state, buffer, sourceConfig); + public void executePartition(SaasWorkerProgressState state, Buffer> buffer, AcknowledgementSet acknowledgementSet) { + client.executePartition(state, buffer, acknowledgementSet); } private void createPartition(List itemInfoList, EnhancedSourceCoordinator coordinator) { diff --git a/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerClient.java b/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerClient.java index f086d916b3..47b033f66a 100644 --- a/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerClient.java +++ b/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerClient.java @@ -1,5 +1,6 @@ package org.opensearch.dataprepper.plugins.source.source_crawler.base; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; @@ -36,9 +37,9 @@ public interface CrawlerClient { /** * Method for executing a particular partition or a chunk of work * - * @param state worker node state holds the details of this particular chunk of work - * @param buffer pipeline buffer to write the results into - * @param sourceConfig pipeline configuration from the yaml + * @param state worker node state holds the details of this particular chunk of work + * @param buffer pipeline buffer to write the results into + * @param acknowledgementSet acknowledgement set to be used to track the completion of the partition */ - void executePartition(SaasWorkerProgressState state, Buffer> buffer, CrawlerSourceConfig sourceConfig); + void executePartition(SaasWorkerProgressState state, Buffer> buffer, AcknowledgementSet acknowledgementSet); } diff --git a/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerSourceConfig.java b/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerSourceConfig.java index 18649e052c..77902bdbc7 100644 --- a/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerSourceConfig.java +++ b/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerSourceConfig.java @@ -6,4 +6,11 @@ public interface CrawlerSourceConfig { int DEFAULT_NUMBER_OF_WORKERS = 1; + + /** + * Boolean to indicate if acknowledgments enabled for this source + * + * @return boolean indicating acknowledgement state + */ + boolean isAcknowledgments(); } diff --git a/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerSourcePlugin.java b/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerSourcePlugin.java index d959033c9c..70c0182e27 100644 --- a/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerSourcePlugin.java +++ b/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerSourcePlugin.java @@ -78,7 +78,8 @@ public void start(Buffer> buffer) { this.executorService.submit(leaderScheduler); //Register worker threaders for (int i = 0; i < sourceConfig.DEFAULT_NUMBER_OF_WORKERS; i++) { - WorkerScheduler workerScheduler = new WorkerScheduler(buffer, coordinator, sourceConfig, crawler); + WorkerScheduler workerScheduler = new WorkerScheduler(sourcePluginName, buffer, coordinator, + sourceConfig, crawler, pluginMetrics, acknowledgementSetManager); this.executorService.submit(new Thread(workerScheduler)); } } diff --git a/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/coordination/scheduler/WorkerScheduler.java b/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/coordination/scheduler/WorkerScheduler.java index 5569080f77..f2fc7e4b40 100644 --- a/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/coordination/scheduler/WorkerScheduler.java +++ b/data-prepper-plugins/saas-source-plugins/source-crawler/src/main/java/org/opensearch/dataprepper/plugins/source/source_crawler/coordination/scheduler/WorkerScheduler.java @@ -1,5 +1,9 @@ package org.opensearch.dataprepper.plugins.source.source_crawler.coordination.scheduler; +import io.micrometer.core.instrument.Counter; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; @@ -21,24 +25,41 @@ */ public class WorkerScheduler implements Runnable { + public static final String ACKNOWLEDGEMENT_SET_SUCCESS_METRIC_NAME = "acknowledgementSetSuccesses"; + public static final String ACKNOWLEDGEMENT_SET_FAILURES_METRIC_NAME = "acknowledgementSetFailures"; + private static final Duration ACKNOWLEDGEMENT_SET_TIMEOUT = Duration.ofSeconds(20); private static final Logger log = LoggerFactory.getLogger(WorkerScheduler.class); private static final int RETRY_BACKOFF_ON_EXCEPTION_MILLIS = 5_000; private static final Duration DEFAULT_SLEEP_DURATION_MILLIS = Duration.ofMillis(10000); - private final EnhancedSourceCoordinator sourceCoordinator; private final CrawlerSourceConfig sourceConfig; private final Crawler crawler; private final Buffer> buffer; + private final PluginMetrics pluginMetrics; + private final AcknowledgementSetManager acknowledgementSetManager; + private final Counter acknowledgementSetSuccesses; + private final Counter acknowledgementSetFailures; + private final String sourcePluginName; + private final String SOURCE_PLUGIN_NAME = "sourcePluginName"; - public WorkerScheduler(Buffer> buffer, + public WorkerScheduler(final String sourcePluginName, + Buffer> buffer, EnhancedSourceCoordinator sourceCoordinator, CrawlerSourceConfig sourceConfig, - Crawler crawler) { + Crawler crawler, + final PluginMetrics pluginMetrics, + final AcknowledgementSetManager acknowledgementSetManager) { this.sourceCoordinator = sourceCoordinator; this.sourceConfig = sourceConfig; this.crawler = crawler; this.buffer = buffer; + this.sourcePluginName = sourcePluginName; + + this.acknowledgementSetManager = acknowledgementSetManager; + this.pluginMetrics = pluginMetrics; + this.acknowledgementSetSuccesses = pluginMetrics.counterWithTags(ACKNOWLEDGEMENT_SET_SUCCESS_METRIC_NAME, SOURCE_PLUGIN_NAME, sourcePluginName); + this.acknowledgementSetFailures = pluginMetrics.counterWithTags(ACKNOWLEDGEMENT_SET_FAILURES_METRIC_NAME, SOURCE_PLUGIN_NAME, sourcePluginName); } @Override @@ -52,7 +73,7 @@ public void run() { sourceCoordinator.acquireAvailablePartition(SaasSourcePartition.PARTITION_TYPE); if (partition.isPresent()) { // Process the partition (source extraction logic) - processPartition(partition.get(), buffer, sourceConfig); + processPartition(partition.get(), buffer); } else { log.debug("No partition available. This thread will sleep for {}", DEFAULT_SLEEP_DURATION_MILLIS); @@ -75,13 +96,31 @@ public void run() { log.warn("SourceItemWorker Scheduler is interrupted, looks like shutdown has triggered"); } - private void processPartition(EnhancedSourcePartition partition, Buffer> buffer, CrawlerSourceConfig sourceConfig) { + private void processPartition(EnhancedSourcePartition partition, Buffer> buffer) { // Implement your source extraction logic here // Update the partition state or commit the partition as needed // Commit the partition to mark it as processed if (partition.getProgressState().isPresent()) { - crawler.executePartition((SaasWorkerProgressState) partition.getProgressState().get(), buffer, sourceConfig); + AcknowledgementSet acknowledgementSet = null; + if (sourceConfig.isAcknowledgments()) { + acknowledgementSet = createAcknowledgementSet(partition); + } + crawler.executePartition((SaasWorkerProgressState) partition.getProgressState().get(), buffer, acknowledgementSet); } sourceCoordinator.completePartition(partition); } + + private AcknowledgementSet createAcknowledgementSet(EnhancedSourcePartition partition) { + return acknowledgementSetManager.create((result) -> { + if (result) { + acknowledgementSetSuccesses.increment(); + sourceCoordinator.completePartition(partition); + log.debug("acknowledgements received for partitionKey: {}", partition.getPartitionKey()); + } else { + acknowledgementSetFailures.increment(); + log.debug("acknowledgements received with false for partitionKey: {}", partition.getPartitionKey()); + sourceCoordinator.giveUpPartition(partition); + } + }, ACKNOWLEDGEMENT_SET_TIMEOUT); + } } diff --git a/data-prepper-plugins/saas-source-plugins/source-crawler/src/test/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerTest.java b/data-prepper-plugins/saas-source-plugins/source-crawler/src/test/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerTest.java index 45d2fcb402..2b3aab7fcc 100644 --- a/data-prepper-plugins/saas-source-plugins/source-crawler/src/test/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerTest.java +++ b/data-prepper-plugins/saas-source-plugins/source-crawler/src/test/java/org/opensearch/dataprepper/plugins/source/source_crawler/base/CrawlerTest.java @@ -5,6 +5,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; @@ -32,7 +33,7 @@ @ExtendWith(MockitoExtension.class) public class CrawlerTest { @Mock - private CrawlerSourceConfig sourceConfig; + private AcknowledgementSet acknowledgementSet; @Mock private EnhancedSourceCoordinator coordinator; @@ -60,8 +61,8 @@ public void crawlerConstructionTest() { @Test public void executePartitionTest() { - crawler.executePartition(state, buffer, sourceConfig); - verify(client).executePartition(state, buffer, sourceConfig); + crawler.executePartition(state, buffer, acknowledgementSet); + verify(client).executePartition(state, buffer, acknowledgementSet); } @Test diff --git a/data-prepper-plugins/saas-source-plugins/source-crawler/src/test/java/org/opensearch/dataprepper/plugins/source/source_crawler/coordination/scheduler/WorkerSchedulerTest.java b/data-prepper-plugins/saas-source-plugins/source-crawler/src/test/java/org/opensearch/dataprepper/plugins/source/source_crawler/coordination/scheduler/WorkerSchedulerTest.java index 8fee799b84..9220fa6e80 100644 --- a/data-prepper-plugins/saas-source-plugins/source-crawler/src/test/java/org/opensearch/dataprepper/plugins/source/source_crawler/coordination/scheduler/WorkerSchedulerTest.java +++ b/data-prepper-plugins/saas-source-plugins/source-crawler/src/test/java/org/opensearch/dataprepper/plugins/source/source_crawler/coordination/scheduler/WorkerSchedulerTest.java @@ -4,6 +4,9 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; @@ -21,6 +24,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeast; @@ -31,6 +35,7 @@ @ExtendWith(MockitoExtension.class) public class WorkerSchedulerTest { + private final String pluginName = "sampleTestPlugin"; @Mock Buffer> buffer; @Mock @@ -39,14 +44,19 @@ public class WorkerSchedulerTest { private CrawlerSourceConfig sourceConfig; @Mock private Crawler crawler; - + @Mock + private PluginMetrics pluginMetrics; + @Mock + private AcknowledgementSetManager acknowledgementSetManager; + @Mock + private AcknowledgementSet acknowledgementSet; @Mock private SourcePartitionStoreItem sourcePartitionStoreItem; - @Test void testUnableToAcquireLeaderPartition() throws InterruptedException { - WorkerScheduler workerScheduler = new WorkerScheduler(buffer, coordinator, sourceConfig, crawler); + WorkerScheduler workerScheduler = new WorkerScheduler(pluginName, buffer, + coordinator, sourceConfig, crawler, pluginMetrics, acknowledgementSetManager); given(coordinator.acquireAvailablePartition(SaasSourcePartition.PARTITION_TYPE)).willReturn(Optional.empty()); ExecutorService executorService = Executors.newSingleThreadExecutor(); @@ -58,12 +68,15 @@ void testUnableToAcquireLeaderPartition() throws InterruptedException { @Test void testLeaderPartitionsCreation() throws InterruptedException { - WorkerScheduler workerScheduler = new WorkerScheduler(buffer, coordinator, sourceConfig, crawler); + WorkerScheduler workerScheduler = new WorkerScheduler(pluginName, buffer, + coordinator, sourceConfig, crawler, pluginMetrics, acknowledgementSetManager); String sourceId = UUID.randomUUID() + "|" + SaasSourcePartition.PARTITION_TYPE; String state = "{\"keyAttributes\":{\"project\":\"project-1\"},\"totalItems\":0,\"loadedItems\":20,\"exportStartTime\":1729391235717,\"itemIds\":[\"GTMS-25\",\"GTMS-24\"]}"; when(sourcePartitionStoreItem.getPartitionProgressState()).thenReturn(state); when(sourcePartitionStoreItem.getSourceIdentifier()).thenReturn(sourceId); + when(sourceConfig.isAcknowledgments()).thenReturn(true); + when(acknowledgementSetManager.create(any(), any())).thenReturn(acknowledgementSet); PartitionFactory factory = new PartitionFactory(); EnhancedSourcePartition sourcePartition = factory.apply(sourcePartitionStoreItem); given(coordinator.acquireAvailablePartition(SaasSourcePartition.PARTITION_TYPE)).willReturn(Optional.of(sourcePartition)); @@ -76,13 +89,14 @@ void testLeaderPartitionsCreation() throws InterruptedException { // Check if crawler was invoked and updated leader lease renewal time SaasWorkerProgressState stateObj = (SaasWorkerProgressState) sourcePartition.getProgressState().get(); - verify(crawler, atLeast(1)).executePartition(stateObj, buffer, sourceConfig); + verify(crawler, atLeast(1)).executePartition(stateObj, buffer, acknowledgementSet); verify(coordinator, atLeast(1)).completePartition(eq(sourcePartition)); } @Test void testEmptyProgressState() throws InterruptedException { - WorkerScheduler workerScheduler = new WorkerScheduler(buffer, coordinator, sourceConfig, crawler); + WorkerScheduler workerScheduler = new WorkerScheduler(pluginName, buffer, + coordinator, sourceConfig, crawler, pluginMetrics, acknowledgementSetManager); String sourceId = UUID.randomUUID() + "|" + SaasSourcePartition.PARTITION_TYPE; when(sourcePartitionStoreItem.getPartitionProgressState()).thenReturn(null); @@ -104,7 +118,8 @@ void testEmptyProgressState() throws InterruptedException { @Test void testExceptionWhileAcquiringWorkerPartition() throws InterruptedException { - WorkerScheduler workerScheduler = new WorkerScheduler(buffer, coordinator, sourceConfig, crawler); + WorkerScheduler workerScheduler = new WorkerScheduler(pluginName, buffer, + coordinator, sourceConfig, crawler, pluginMetrics, acknowledgementSetManager); given(coordinator.acquireAvailablePartition(SaasSourcePartition.PARTITION_TYPE)).willThrow(RuntimeException.class); ExecutorService executorService = Executors.newSingleThreadExecutor(); @@ -119,7 +134,8 @@ void testExceptionWhileAcquiringWorkerPartition() throws InterruptedException { @Test void testWhenNoPartitionToWorkOn() throws InterruptedException { - WorkerScheduler workerScheduler = new WorkerScheduler(buffer, coordinator, sourceConfig, crawler); + WorkerScheduler workerScheduler = new WorkerScheduler(pluginName, buffer, + coordinator, sourceConfig, crawler, pluginMetrics, acknowledgementSetManager); ExecutorService executorService = Executors.newSingleThreadExecutor(); executorService.submit(workerScheduler); @@ -134,7 +150,8 @@ void testWhenNoPartitionToWorkOn() throws InterruptedException { @Test void testRetryBackOffTriggeredWhenExceptionOccurred() throws InterruptedException { - WorkerScheduler workerScheduler = new WorkerScheduler(buffer, coordinator, sourceConfig, crawler); + WorkerScheduler workerScheduler = new WorkerScheduler(pluginName, buffer, + coordinator, sourceConfig, crawler, pluginMetrics, acknowledgementSetManager); given(coordinator.acquireAvailablePartition(SaasSourcePartition.PARTITION_TYPE)).willThrow(RuntimeException.class); ExecutorService executorService = Executors.newSingleThreadExecutor();