Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alternative for managing task array status in Google Batch #5723

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -463,18 +463,16 @@ class GoogleBatchTaskHandler extends TaskHandler implements FusionAwareTask {
}

protected String getStateFromTaskStatus() {
final tasks = client.listTasks(jobId)
if( !tasks.iterator().hasNext() ) {
return getStateFromJobStatus()
}
final now = System.currentTimeMillis()
final delta = now - timestamp;
if( !taskState || delta >= 1_000) {
try {
final status = client.getTaskStatus(jobId, taskId)
final status = client.getTaskInArrayStatus(jobId, taskId)
if( status ) {
inspectTaskStatus(status)
}catch (NotFoundException e) {
manageNotFound(tasks)
} else {
// If no task status retrieved check job status
final jobStatus = client.getJobStatus(jobId)
inspectJobStatus(jobStatus)
}
}
return taskState
Expand Down Expand Up @@ -505,20 +503,6 @@ class GoogleBatchTaskHandler extends TaskHandler implements FusionAwareTask {
}
}

protected String manageNotFound( Iterable<Task> tasks) {
// If task is array, check if the in the task list
for (Task t in tasks) {
if (t.name == client.generateTaskName(jobId, taskId)) {
inspectTaskStatus(t.status)
return taskState
}
}
// if not array or it task is not in the list, check job status.
final status = client.getJobStatus(jobId)
inspectJobStatus(status)
return taskState
}

protected String inspectJobStatus(JobStatus status) {
final newState = status?.state as String
if (newState) {
Expand Down Expand Up @@ -571,6 +555,8 @@ class GoogleBatchTaskHandler extends TaskHandler implements FusionAwareTask {
task.stderr = errorFile
}
status = TaskStatus.COMPLETED
if( task.isChild )
client.removeFromArrayTasks(jobId, taskId)
return true
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ import groovy.util.logging.Slf4j
@Slf4j
@CompileStatic
class BatchClient {

private static long TASK_STATE_INVALID_TIME = 1_000
pditommaso marked this conversation as resolved.
Show resolved Hide resolved
protected String projectId
protected String location
protected BatchServiceClient batchServiceClient
protected BatchConfig config
private Map<String, TaskStatusRecord> arrayTaskStatus = new HashMap<String, TaskStatusRecord>()

BatchClient(BatchConfig config) {
this.config = config
Expand Down Expand Up @@ -198,4 +199,40 @@ class BatchClient {
// apply the action with
return Failsafe.with(policy).get(action)
}


TaskStatus getTaskInArrayStatus(String jobId, String taskId) {
final taskName = generateTaskName(jobId,taskId)
final now = System.currentTimeMillis()
TaskStatusRecord record = arrayTaskStatus.get(taskName)
if( !record || now - record.timestamp > TASK_STATE_INVALID_TIME ){
log.debug("[GOOGLE BATCH] Updating tasks status for job $jobId")
updateArrayTasks(jobId, now)
record = arrayTaskStatus.get(taskName)
}
return record?.status
}

private void updateArrayTasks(String jobId, long now){
for( Task t: listTasks(jobId) ){
arrayTaskStatus.put(t.name, new TaskStatusRecord(t.status, now))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there can be any race condition on accessing the arrayTaskStatus map?

}

pditommaso marked this conversation as resolved.
Show resolved Hide resolved
}

void removeFromArrayTasks(String jobId, String taskId){
final taskName = generateTaskName(jobId,taskId)
TaskStatusRecord record = arrayTaskStatus.remove(taskName)
}
}

@CompileStatic
class TaskStatusRecord {
protected TaskStatus status
protected long timestamp

TaskStatusRecord(TaskStatus status, long timestamp) {
this.status = status
this.timestamp = timestamp
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,16 @@ class GoogleBatchTaskHandlerTest extends Specification {

}

TaskStatus makeTaskStatus(String desc) {
TaskStatus.newBuilder()
.addStatusEvents(
TaskStatus makeTaskStatus(TaskStatus.State state, String desc) {
def builder = TaskStatus.newBuilder()
if (state)
builder.setState(state)
if (desc)
builder.addStatusEvents(
StatusEvent.newBuilder()
.setDescription(desc)
)
.build()
builder.build()
}

def 'should detect spot failures from status event'() {
Expand All @@ -486,8 +489,8 @@ class GoogleBatchTaskHandlerTest extends Specification {

when:
client.getTaskStatus(jobId, taskId) >>> [
makeTaskStatus('Task failed due to Spot VM preemption with exit code 50001.'),
makeTaskStatus('Task succeeded')
makeTaskStatus(null,'Task failed due to Spot VM preemption with exit code 50001.'),
makeTaskStatus(null, 'Task succeeded')
]
then:
handler.getJobError().message == "Task failed due to Spot VM preemption with exit code 50001."
Expand Down Expand Up @@ -639,15 +642,15 @@ class GoogleBatchTaskHandlerTest extends Specification {
client.generateTaskName(jobId, taskId) >> "$jobId/group0/$taskId"
//Force errors
client.getTaskStatus(jobId, taskId) >> { throw new NotFoundException(new Exception("Error"), GrpcStatusCode.of(Status.Code.NOT_FOUND), false) }
client.listTasks(jobId) >> TASK_LIST
client.getTaskInArrayStatus(jobId, taskId) >> TASK_STATUS
client.getJobStatus(jobId) >> makeJobStatus(JOB_STATUS, "")
then:
handler.getTaskState() == EXPECTED

where:
EXPECTED | JOB_STATUS | TASK_LIST
"FAILED" | JobStatus.State.FAILED | {[ makeTask("1/group0/2", TaskStatus.State.PENDING), makeTask("1/group0/3", TaskStatus.State.PENDING) ].iterator() } // Task not in the list, get from job
"SUCCEEDED" | JobStatus.State.FAILED | {[ makeTask("1/group0/1", TaskStatus.State.SUCCEEDED), makeTask("1/group0/2", TaskStatus.State.PENDING)].iterator() } //Task in the list, get from task status
EXPECTED | JOB_STATUS | TASK_STATUS
"FAILED" | JobStatus.State.FAILED | null // Task not in the list, get from job
"SUCCEEDED" | JobStatus.State.FAILED | makeTaskStatus(TaskStatus.State.SUCCEEDED, "") // get from task status
}

def makeTask(String name, TaskStatus.State state){
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2013-2024, Seqera Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package nextflow.cloud.google.batch.client

import com.google.cloud.batch.v1.Task
import com.google.cloud.batch.v1.TaskName
import com.google.cloud.batch.v1.TaskStatus
import spock.lang.Specification

/**
*
* @author Jorge Ejarque <[email protected]>
*/
class BatchClientTest extends Specification{



def 'should return task status with getTaskInArray' () {
given:
def project = 'project-id'
def location = 'location-id'
def job1 = 'job1-id'
def task1 = 'task1-id'
def task1Name = TaskName.of(project, location, job1, 'group0', task1).toString()
def job2 = 'job2-id'
def task2 = 'task2-id'
def task2Name = TaskName.of(project, location, job2, 'group0', task2).toString()
def job3 = 'job3-id'
def task3 = 'task3-id'
def task3Name = TaskName.of(project, location, job3, 'group0', task3).toString()
def now = System.currentTimeMillis()
def arrayTasks = new HashMap<String,TaskStatusRecord>()
def client = Spy( new BatchClient( projectId: project, location: location, arrayTaskStatus: arrayTasks ) )

when:
client.listTasks(job2) >> {
def list = new LinkedList<>()
list.add(makeTask(task2Name, TaskStatus.State.FAILED))
return list
}
client.listTasks(job3) >> {
def list = new LinkedList<>()
list.add(makeTask(task3Name, TaskStatus.State.SUCCEEDED))
return list
}
arrayTasks.put(task1Name, makeTaskStatusRecord(TaskStatus.State.RUNNING, System.currentTimeMillis()))
arrayTasks.put(task2Name, makeTaskStatusRecord(TaskStatus.State.PENDING, System.currentTimeMillis() - 1_001))

then:
// recent cached task
client.getTaskInArrayStatus(job1, task1).state == TaskStatus.State.RUNNING
// Outdated cached task
client.getTaskInArrayStatus(job2, task2).state == TaskStatus.State.FAILED
// no cached task
client.getTaskInArrayStatus(job3, task3).state == TaskStatus.State.SUCCEEDED
}

def TaskStatusRecord makeTaskStatusRecord(TaskStatus.State state, long timestamp) {
return new TaskStatusRecord(TaskStatus.newBuilder().setState(state).build(), timestamp)

}

def makeTask(String name, TaskStatus.State state){
Task.newBuilder().setName(name)
.setStatus(TaskStatus.newBuilder().setState(state).build())
.build()

}

}