Skip to content

Commit

Permalink
drop noise rate
Browse files Browse the repository at this point in the history
  • Loading branch information
junkangwu committed Oct 23, 2024
1 parent 6f3abb8 commit 8280272
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
3 changes: 1 addition & 2 deletions config/loss/sft.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
-name: sft
-mode_weight: 0.0
+name: sft
7 changes: 3 additions & 4 deletions preference_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def split_prompt_and_responses(ex):

return data

-def get_dataset(name: str, split: str, silent: bool = False, cache_dir: str = None, noise_rate=0.):
+def get_dataset(name: str, split: str, silent: bool = False, cache_dir: str = None):
"""Load the given dataset by name. Supported by default are 'shp', 'hh', and 'se'."""
if name == 'shp':
data = get_shp(split, silent=silent, cache_dir=cache_dir)
Expand Down Expand Up @@ -288,8 +288,7 @@ def get_batch_iterator(names: List[str],
n_examples: Optional[int] = None,
seed:int = 0,
silent: bool = False,
-                       cache_dir: Optional[str] = None,
-                       noise_rate=0.) -> Iterator[Dict]:
+                       cache_dir: Optional[str] = None) -> Iterator[Dict]:
"""Get an iterator over batches of data. Stops after n_epochs or n_examples, whichever comes first.
Args:
Expand Down Expand Up @@ -317,7 +316,7 @@ def get_batch_iterator(names: List[str],
flat_data = []
for name in names:
truncation_mode = 'keep_end' if name == 'hh' else 'keep_start'
-        for prompt, data in get_dataset(name, split, silent=silent, cache_dir=cache_dir, noise_rate=noise_rate).items():
+        for prompt, data in get_dataset(name, split, silent=silent, cache_dir=cache_dir).items():
flat_data.append((prompt, data['responses'], data['pairs'], data['sft_target'], truncation_mode))

collate_fn = get_collate_fn(tokenizer)
Expand Down
2 changes: 1 addition & 1 deletion trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: st
self.loss_mean = torch.zeros(1, device='cuda')
self.loss_std = torch.zeros(1, device='cuda')

-            self.train_iterator = get_batch_iterator(**data_iterator_kwargs, split='train', n_epochs=config.n_epochs, n_examples=config.n_examples, batch_size=config.batch_size, silent=rank != 0, cache_dir=get_local_dir(config.local_dirs), noise_rate=config.noise_rate)
+            self.train_iterator = get_batch_iterator(**data_iterator_kwargs, split='train', n_epochs=config.n_epochs, n_examples=config.n_examples, batch_size=config.batch_size, silent=rank != 0, cache_dir=get_local_dir(config.local_dirs))
rank0_print(f'Loaded train data iterator')
self.eval_iterator = get_batch_iterator(**data_iterator_kwargs, split='test', n_examples=config.n_eval_examples, batch_size=config.eval_batch_size, silent=rank != 0, cache_dir=get_local_dir(config.local_dirs))
self.eval_batches = list(self.eval_iterator)
Expand Down

0 comments on commit 8280272

Please sign in to comment.