Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Go livepeer liveportrait #3200

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions ai/file_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func NewFileWorker(files map[string]string) *FileWorker {
return &FileWorker{files: files}
}

func (w *FileWorker) TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
func (w *FileWorker) TextToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) {
fname, ok := w.files["text-to-image"]
if !ok {
return nil, errors.New("text-to-image response file not found")
Expand All @@ -36,7 +36,7 @@ func (w *FileWorker) TextToImage(ctx context.Context, req worker.TextToImageJSON
return &resp, nil
}

func (w *FileWorker) ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
func (w *FileWorker) ImageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
fname, ok := w.files["image-to-image"]
if !ok {
return nil, errors.New("image-to-image response file not found")
Expand All @@ -55,7 +55,7 @@ func (w *FileWorker) ImageToImage(ctx context.Context, req worker.ImageToImageMu
return &resp, nil
}

func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) {
func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) {
fname, ok := w.files["image-to-video"]
if !ok {
return nil, errors.New("image-to-video response file not found")
Expand All @@ -74,7 +74,26 @@ func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.ImageToVideoMu
return &resp, nil
}

func (w *FileWorker) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
// LivePortrait returns a canned live-portrait response loaded from disk.
//
// The file path is looked up under the "live-portrait" key of w.files and the
// file contents are decoded as JSON into a worker.VideoResponse. ctx and req
// are not consulted. An error is returned when the key is absent, the file
// cannot be read, or the JSON does not decode.
func (w *FileWorker) LivePortrait(ctx context.Context, req worker.LivePortraitLivePortraitPostMultipartRequestBody) (*worker.VideoResponse, error) {
	respFile, found := w.files["live-portrait"]
	if !found {
		return nil, errors.New("live-portrait response file not found")
	}

	raw, readErr := os.ReadFile(respFile)
	if readErr != nil {
		return nil, readErr
	}

	out := new(worker.VideoResponse)
	if decodeErr := json.Unmarshal(raw, out); decodeErr != nil {
		return nil, decodeErr
	}
	return out, nil
}

func (w *FileWorker) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
fname, ok := w.files["upscale"]
if !ok {
return nil, errors.New("upscale response file not found")
Expand Down
14 changes: 14 additions & 0 deletions cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,20 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_SegmentAnything2, config.ModelID, autoPrice)
}
case "live-portrait":
_, ok := capabilityConstraints[core.Capability_LivePortrait]
if !ok {
aiCaps = append(aiCaps, core.Capability_LivePortrait)
capabilityConstraints[core.Capability_LivePortrait] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}

capabilityConstraints[core.Capability_LivePortrait].Models[config.ModelID] = modelConstraint

if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_LivePortrait, config.ModelID, autoPrice)
}
}

if len(aiCaps) > 0 {
Expand Down
13 changes: 7 additions & 6 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ import (
var errPipelineNotAvailable = errors.New("pipeline not available")

type AI interface {
TextToImage(context.Context, worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error)
ImageToImage(context.Context, worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
ImageToVideo(context.Context, worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error)
SegmentAnything2(context.Context, worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
TextToImage(context.Context, worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error)
ImageToImage(context.Context, worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
ImageToVideo(context.Context, worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
LivePortrait(context.Context, worker.LivePortraitLivePortraitPostMultipartRequestBody) (*worker.VideoResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
HasCapacity(pipeline, modelID string) bool
Expand Down
3 changes: 3 additions & 0 deletions core/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ const (
Capability_Upscale
Capability_AudioToText
Capability_SegmentAnything2
Capability_LivePortrait
)

var CapabilityNameLookup = map[Capability]string{
Expand Down Expand Up @@ -116,6 +117,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_Upscale: "Upscale",
Capability_AudioToText: "Audio to text",
Capability_SegmentAnything2: "Segment anything 2",
Capability_LivePortrait: "Live Portrait",
}

var CapabilityTestLookup = map[Capability]CapabilityTest{
Expand Down Expand Up @@ -207,6 +209,7 @@ func OptionalCapabilities() []Capability {
Capability_Upscale,
Capability_AudioToText,
Capability_SegmentAnything2,
Capability_LivePortrait,
}
}

Expand Down
121 changes: 101 additions & 20 deletions core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"crypto/sha256"
"errors"
"fmt"
"io/ioutil"
"image"
"math/big"
"net/url"
"os"
Expand Down Expand Up @@ -110,30 +110,34 @@ func (orch *orchestrator) TranscoderResults(tcID int64, res *RemoteTranscoderRes
orch.node.TranscoderManager.transcoderResults(tcID, res)
}

func (orch *orchestrator) TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
// TextToImage hands the request to orch.node.textToImage and returns that
// call's response and error unchanged.
func (orch *orchestrator) TextToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) {
return orch.node.textToImage(ctx, req)
}

func (orch *orchestrator) ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
// ImageToImage hands the request to orch.node.imageToImage and returns that
// call's response and error unchanged.
func (orch *orchestrator) ImageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.imageToImage(ctx, req)
}

func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
// ImageToVideo hands the request to orch.node.imageToVideo and returns that
// call's response and error unchanged. Note that the result is carried in a
// worker.ImageResponse rather than a worker.VideoResponse.
func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.imageToVideo(ctx, req)
}

func (orch *orchestrator) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
// Upscale hands the request to orch.node.upscale and returns that call's
// response and error unchanged.
func (orch *orchestrator) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.upscale(ctx, req)
}

func (orch *orchestrator) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
// AudioToText hands the request to orch.node.AudioToText and returns that
// call's response and error unchanged.
func (orch *orchestrator) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
return orch.node.AudioToText(ctx, req)
}

func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
// SegmentAnything2 hands the request to orch.node.SegmentAnything2 and
// returns that call's response and error unchanged.
func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return orch.node.SegmentAnything2(ctx, req)
}

// LivePortrait hands the request to orch.node.LivePortrait and returns that
// call's response and error unchanged.
func (orch *orchestrator) LivePortrait(ctx context.Context, req worker.LivePortraitLivePortraitPostMultipartRequestBody) (*worker.VideoResponse, error) {
return orch.node.LivePortrait(ctx, req)
}

func (orch *orchestrator) ProcessPayment(ctx context.Context, payment net.Payment, manifestID ManifestID) error {
if orch.node == nil || orch.node.Recipient == nil {
return nil
Expand Down Expand Up @@ -630,19 +634,24 @@ func (n *LivepeerNode) transcodeFrames(ctx context.Context, sessionID string, ur

// We only support base64 png data urls right now
// We will want to support HTTP and file urls later on as well
dirPath := path.Join(n.WorkDir, "input", sessionID+"_"+string(RandomManifestID()))
dirPath := path.Join(WorkDir, "input", sessionID+"_"+string(RandomManifestID()))
JJassonn69 marked this conversation as resolved.
Show resolved Hide resolved
fnamep = &dirPath
if err := os.MkdirAll(dirPath, 0700); err != nil {
clog.Errorf(ctx, "Transcoder cannot create frames dir err=%q", err)
return terr(err)
}
var wg sync.WaitGroup // Add a WaitGroup to wait for all goroutines to finish
for i, url := range urls {
fname := path.Join(dirPath, strconv.Itoa(i)+".png")
if err := worker.SaveImageB64DataUrl(url, fname); err != nil {
clog.Errorf(ctx, "Transcoder failed to save image from url err=%q", err)
return terr(err)
}
wg.Add(1) // Increment the WaitGroup counter
go func(i int, url string) {
defer wg.Done() // Decrement the counter when the goroutine completes
fname := path.Join(dirPath, strconv.Itoa(i)+".png")
if err := worker.SaveImageB64DataUrl(url, fname); err != nil {
clog.Errorf(ctx, "Transcoder failed to save image from url err=%q", err)
}
}(i, url) // Pass the loop variables to the goroutine
}
wg.Wait()

// Use local software transcoder instead of node's configured transcoder
// because if the node is using a nvidia transcoder there may be sporadic
Expand Down Expand Up @@ -750,7 +759,7 @@ func (n *LivepeerNode) transcodeSeg(ctx context.Context, config transcodeConfig,
// Create input file from segment. Removed after claiming complete or error
fname := path.Join(n.WorkDir, inName)
fnamep = &fname
if err := ioutil.WriteFile(fname, seg.Data, 0644); err != nil {
if err := os.WriteFile(fname, seg.Data, 0644); err != nil {
clog.Errorf(ctx, "Transcoder cannot write file err=%q", err)
return terr(err)
}
Expand Down Expand Up @@ -951,27 +960,27 @@ func (n *LivepeerNode) serveTranscoder(stream net.Transcoder_RegisterTranscoderS
}
}

func (n *LivepeerNode) textToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
// textToImage forwards the request to n.AIWorker.TextToImage and returns its
// response and error unchanged.
func (n *LivepeerNode) textToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.TextToImage(ctx, req)
}

func (n *LivepeerNode) imageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
// imageToImage forwards the request to n.AIWorker.ImageToImage and returns
// its response and error unchanged.
func (n *LivepeerNode) imageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.ImageToImage(ctx, req)
}

func (n *LivepeerNode) upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
// upscale forwards the request to n.AIWorker.Upscale and returns its response
// and error unchanged.
func (n *LivepeerNode) upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.Upscale(ctx, req)
}

func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
// AudioToText forwards the request to n.AIWorker.AudioToText and returns its
// response and error unchanged.
func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
return n.AIWorker.AudioToText(ctx, req)
}

func (n *LivepeerNode) SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
// SegmentAnything2 forwards the request to n.AIWorker.SegmentAnything2 and
// returns its response and error unchanged.
func (n *LivepeerNode) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return n.AIWorker.SegmentAnything2(ctx, req)
}

func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
// We might support generating more than one video in the future (i.e. multiple input images/prompts)
numVideos := 1

Expand Down Expand Up @@ -1051,6 +1060,78 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideo
return &worker.ImageResponse{Images: videos}, nil
}

// LivePortrait runs the live-portrait pipeline on the node's AI worker and
// packages the returned frames as MP4 segments.
//
// The request is passed to n.AIWorker.LivePortrait as-is. Each batch of frames
// in the worker response is then transcoded into one MP4 through
// n.transcodeFrames at a fixed 30 fps / 6000k bitrate, using the pixel
// dimensions of req.SourceImage as the output resolution. Every segment is
// saved through the storage session returned with the transcode result, and
// the saved URI becomes one worker.Media entry.
//
// It returns a VideoResponse holding a single batch with one Media per frame
// batch of the worker response. Any failure (worker error, unreadable or
// undecodable source image, transcode error, empty transcode output, or save
// error) aborts the call and yields (nil, err).
func (n *LivepeerNode) LivePortrait(ctx context.Context, req worker.LivePortraitLivePortraitPostMultipartRequestBody) (*worker.VideoResponse, error) {
	// Output frame rate used for both the input and output profiles.
	const framerate = 30

	// handle frames from api
	start := time.Now()
	resp, err := n.AIWorker.LivePortrait(ctx, req)
	if err != nil {
		return nil, err
	}
	clog.V(common.DEBUG).Infof(ctx, "Animating the video took=%v", time.Since(start))

	sessionID := string(RandomManifestID())

	// The output video keeps the source image's dimensions. Surface a read
	// failure explicitly instead of letting it show up as a confusing decode
	// error on empty data.
	sourceImage, err := req.SourceImage.Bytes()
	if err != nil {
		return nil, fmt.Errorf("failed to read source image: %w", err)
	}
	// Only the dimensions are needed, so read the image header rather than
	// decoding every pixel.
	// NOTE(review): image.DecodeConfig only recognises formats whose decoders
	// are registered (e.g. via a blank import of image/png or image/jpeg) —
	// confirm this file or package registers the formats clients may send.
	imgCfg, _, err := image.DecodeConfig(bytes.NewReader(sourceImage))
	if err != nil {
		return nil, fmt.Errorf("failed to decode source image: %w", err)
	}

	inProfile := ffmpeg.VideoProfile{
		Framerate:    uint(framerate),
		FramerateDen: 1,
	}
	outProfile := ffmpeg.VideoProfile{
		Name:       "live-portrait",
		Framerate:  uint(framerate),
		Bitrate:    "6000k",
		Resolution: fmt.Sprintf("%vx%v", imgCfg.Width, imgCfg.Height),
		Format:     ffmpeg.FormatMP4,
	}

	// Transcode frames into segments.
	videos := make([]worker.Media, len(resp.Frames))
	for i, batch := range resp.Frames {
		// Create slice of frame urls for a batch
		urls := make([]string, len(batch))
		for j, frame := range batch {
			urls[j] = frame.Url
		}

		// Transcode slice of frame urls into a segment
		res := n.transcodeFrames(ctx, sessionID, urls, inProfile, outProfile)
		if res.Err != nil {
			return nil, res.Err
		}

		// Assume only single rendition right now, but never index into an
		// empty result: that would panic instead of returning an error.
		if len(res.TranscodeData.Segments) == 0 {
			return nil, errors.New("live-portrait transcode produced no segments")
		}
		seg := res.TranscodeData.Segments[0]

		name := fmt.Sprintf("%v.mp4", RandomManifestID())
		uri, err := res.OS.SaveData(ctx, name, bytes.NewReader(seg.Data), nil, 0)
		if err != nil {
			return nil, err
		}

		media := worker.Media{Url: uri}
		// NOTE: Seed is consistent for video; NSFW check applies to first frame only.
		if len(batch) > 0 {
			media.Nsfw = batch[0].Nsfw
			media.Seed = batch[0].Seed
		}
		videos[i] = media
	}

	return &worker.VideoResponse{Frames: [][]worker.Media{videos}}, nil
}

func (rtm *RemoteTranscoderManager) transcoderResults(tcID int64, res *RemoteTranscoderResult) {
remoteChan, err := rtm.getTaskChan(tcID)
if err != nil {
Expand Down
1 change: 0 additions & 1 deletion core/transcoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ func (lt *LocalTranscoder) Transcode(ctx context.Context, md *SegTranscodingMeta
// Returns UnrecoverableError instead of panicking to gracefully notify orchestrator about transcoder's failure
defer recoverFromPanic(&retErr)

// Set up in / out config
in := &ffmpeg.TranscodeOptionsIn{
Fname: md.Fname,
Accel: ffmpeg.Software,
Expand Down
Loading