Multiprocessing text generation
eleurent committed Nov 15, 2023
1 parent b56a9a3 commit 049a28a
Showing 1 changed file with 29 additions and 22 deletions.
51 changes: 29 additions & 22 deletions generation/text/generate_text.py
@@ -5,6 +5,8 @@
 from dotenv import load_dotenv
 import openai
 import wikipedia
+import functools
+import multiprocessing
 
 sys.path.append('..')
 import nature_go_client
@@ -20,6 +22,29 @@ def get_species(client, batch_size=5, ordering=None):
     print(f'Found {len(species_list)} species')
     return pd.DataFrame(species_list)
 
+def generate(species, client):
+    print('#############################')
+    print(f'Starting generation for {species.scientificNameWithoutAuthor}')
+    try:
+        summary = summary_generation.generate_summaries(common_name=species.display_name, scientific_name=species.scientificNameWithoutAuthor, material=None, prompt=summary_prompt.summary_v7)
+    except wikipedia.PageError as e:
+        print(e)
+        return
+    if not all(f'part_{i}' in summary.keys() for i in range(1, 4)):
+        print(f'Problem with summary generation for {species.scientificNameWithoutAuthor}: keys: {list(summary.keys())}.')
+        return
+    print(f'Generated summaries for {species.scientificNameWithoutAuthor}.')
+    client.update_species_field(species.id, 'descriptions', [summary['part_1'], summary['part_2'], summary['part_3']])
+    print('Uploaded summaries.')
+    material = ' '.join([summary['part_1'], summary['part_2'], summary['part_3']])
+    questions = question_generation.generate_questions(common_name=species.display_name, scientific_name=species.scientificNameWithoutAuthor, material=material)
+    if not questions:
+        print(f'Problem with question generation for {species.scientificNameWithoutAuthor}: {questions}.')
+        return
+    print(f'Generated questions for {species.scientificNameWithoutAuthor}.')
+    client.post_species_questions(species.id, questions)
+    print('Uploaded questions.')
+
 
 def main(args):
     load_dotenv()
@@ -31,28 +56,10 @@ def main(args):
 
     while True:
         species_batch = get_species(client, batch_size=args.batch_size, ordering=args.ordering)
-        print('#############################')
-        for (_, species) in species_batch.iterrows():
-            print(f'Starting generation for {species.scientificNameWithoutAuthor}')
-            try:
-                summary = summary_generation.generate_summaries(common_name=species.display_name, scientific_name=species.scientificNameWithoutAuthor, material=None, prompt=summary_prompt.summary_v7)
-            except wikipedia.PageError as e:
-                print(e)
-                continue
-            if not all(f'part_{i}' in summary.keys() for i in range(1, 4)):
-                print(f'Problem with summary generation for {species.scientificNameWithoutAuthor}: keys: {list(summary.keys())}.')
-                continue
-            print(f'Generated summaries for {species.scientificNameWithoutAuthor}.')
-            client.update_species_field(species.id, 'descriptions', [summary['part_1'], summary['part_2'], summary['part_3']])
-            print('Uploaded summaries.')
-            material = ' '.join([summary['part_1'], summary['part_2'], summary['part_3']])
-            questions = question_generation.generate_questions(common_name=species.display_name, scientific_name=species.scientificNameWithoutAuthor, material=material)
-            if not questions:
-                print(f'Problem with question generation for {species.scientificNameWithoutAuthor}: {questions}.')
-                continue
-            print(f'Generated questions for {species.scientificNameWithoutAuthor}.')
-            client.post_species_questions(species.id, questions)
-            print('Uploaded questions.')
+        species_list = [species for (_, species) in species_batch.iterrows()]
+        pool = multiprocessing.Pool(processes=multiprocessing.cpu_count() - 1)
+        generate_and_upload = functools.partial(generate, client=client)
+        pool.map(generate_and_upload, species_list)
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
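The new while-loop body replaces the sequential for-loop with a process pool: the batch is flattened into a list of species rows, generate is bound to the shared API client with functools.partial, and Pool.map fans the work out over cpu_count() - 1 worker processes. A minimal, self-contained sketch of the same pattern follows; the client object and per-item work are hypothetical stand-ins, not the repository's actual nature_go_client or generation modules.

import functools
import multiprocessing


def process_item(item, client):
    # Stand-in for the per-species work (summary + question generation and upload).
    print(f'[{client}] processed {item}')


def main():
    client = 'demo-client'  # hypothetical placeholder for the shared API client
    items = ['species_a', 'species_b', 'species_c']
    # Bind the shared client so the pool only has to pass each item to the worker.
    worker = functools.partial(process_item, client=client)
    # Keep one CPU free; map blocks until every item in the batch has been handled.
    with multiprocessing.Pool(processes=max(1, multiprocessing.cpu_count() - 1)) as pool:
        pool.map(worker, items)


if __name__ == '__main__':
    main()

Note that Pool.map pickles both the bound function and each argument, so whatever is passed as the client must be picklable for this pattern to work in the real script; the sketch sidesteps that by binding a plain string.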
