diff --git a/projects/bb3/agents/opt_api_agent.py b/projects/bb3/agents/opt_api_agent.py index 6a366c59afc..d8519dc87e4 100644 --- a/projects/bb3/agents/opt_api_agent.py +++ b/projects/bb3/agents/opt_api_agent.py @@ -599,6 +599,12 @@ def add_cmdline_args( help='if a generation is returned with a newline character, set True to take last newline. ' 'set False to take first new line.', ) + parser.add_argument( + '--generation-allow-newline', + type='bool', + default=False, + help='if a generation is returned with a newline character, set True to return all generated tokens.', + ) parser.add_argument( '--memory-decision-use-memories', type='bool', @@ -713,7 +719,7 @@ def get_gen_results(self, observations: List[Message], **gen_params): if not APIUtils.is_request_failed_response(r): r['choices'][0]['text'] = r['choices'][0]['text'].strip("\n") - if any( + if not self.opt.get('generation_allow_newline') and any( '\n' in res['choices'][0]['text'] for res in results if not APIUtils.is_request_failed_response(res) diff --git a/projects/bb3/tests/opt_presets.py b/projects/bb3/tests/opt_presets.py index 172e12a840a..8687586daf5 100644 --- a/projects/bb3/tests/opt_presets.py +++ b/projects/bb3/tests/opt_presets.py @@ -255,6 +255,20 @@ "grm_partner_prefix": "Partner", "orm_partner_prefix": "Partner", "brm_partner_prefix": "Partner", + "sdm_generation_allow_newline": False, + "mdm_generation_allow_newline": False, + "sgm_generation_allow_newline": False, + "mgm_generation_allow_newline": False, + "mkm_generation_allow_newline": False, + "ckm_generation_allow_newline": False, + "skm_generation_allow_newline": False, + "srm_generation_allow_newline": False, + "crm_generation_allow_newline": False, + "mrm_generation_allow_newline": False, + "vrm_generation_allow_newline": False, + "grm_generation_allow_newline": False, + "orm_generation_allow_newline": False, + "brm_generation_allow_newline": False, "self_prefix": "Person 2", "self_memory_prefix": "Person 2's Persona", "partner_prefix": "Person 1",