Skip to content

Commit

Permalink
Add optional mask for dalle's edit api (#157) (#169)
Browse files Browse the repository at this point in the history
* Add optional mask for dalle's edit api
  • Loading branch information
YufeiG authored Jan 6, 2023
1 parent a730583 commit 70c2a85
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
9 changes: 5 additions & 4 deletions openai/api_resources/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ async def acreate_variation(
def _prepare_create_edit(
cls,
image,
mask,
mask=None,
api_key=None,
api_base=None,
api_type=None,
Expand All @@ -179,14 +179,15 @@ def _prepare_create_edit(
for key, value in params.items():
files.append((key, (None, value)))
files.append(("image", ("image", image, "application/octet-stream")))
files.append(("mask", ("mask", mask, "application/octet-stream")))
if mask is not None:
files.append(("mask", ("mask", mask, "application/octet-stream")))
return requestor, url, files

@classmethod
def create_edit(
cls,
image,
mask,
mask=None,
api_key=None,
api_base=None,
api_type=None,
Expand Down Expand Up @@ -215,7 +216,7 @@ def create_edit(
async def acreate_edit(
cls,
image,
mask,
mask=None,
api_key=None,
api_base=None,
api_type=None,
Expand Down
8 changes: 5 additions & 3 deletions openai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,10 @@ def create_variation(cls, args):
def create_edit(cls, args):
with open(args.image, "rb") as file_reader:
image_reader = BufferReader(file_reader.read(), desc="Upload progress")
with open(args.mask, "rb") as file_reader:
mask_reader = BufferReader(file_reader.read(), desc="Upload progress")
mask_reader = None
if args.mask is not None:
with open(args.mask, "rb") as file_reader:
mask_reader = BufferReader(file_reader.read(), desc="Upload progress")
resp = openai.Image.create_edit(
image=image_reader,
mask=mask_reader,
Expand Down Expand Up @@ -893,7 +895,7 @@ def help(args):
"-M",
"--mask",
type=str,
required=True,
required=False,
help="Path to a mask image. It should be the same size as the image you're editing and a RGBA PNG image. The Alpha channel acts as the mask.",
)
sub.set_defaults(func=Image.create_edit)
Expand Down

0 comments on commit 70c2a85

Please sign in to comment.