Skip to content

Commit

Permalink
Dolphin gpt 3.5 data mix (#3606)
Browse files Browse the repository at this point in the history
- Added dolphin random data mix to form conversations from gpt 3.5 file.
- Instructions of the same type are only considered while picking at
random to form conversation
- Also ensured that same samples are not considered more than once 

Configure
```
- dolphin-mix
        num_samples: 100000
        max_char_len: 32000
        seed: 44
```
  • Loading branch information
shahules786 authored Jul 25, 2023
1 parent bc5b70d commit c2f444d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
4 changes: 3 additions & 1 deletion model/model_training/custom_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from model_training.custom_datasets.instruction import INSTRUCTION_DATASETS, InstructionDataset
from model_training.custom_datasets.oasst_dataset import load_oasst_export
from model_training.custom_datasets.pretrain_datasets import FanFics, RedPajama
from model_training.custom_datasets.prompt_dialogue import Gpt4All, OrcaChat, load_oig_file
from model_training.custom_datasets.prompt_dialogue import DolphinMix, Gpt4All, OrcaChat, load_oig_file
from model_training.custom_datasets.qa_datasets import (
SODA,
AlpacaGpt4,
Expand Down Expand Up @@ -176,6 +176,8 @@ def get_one_dataset(
dataset = GPTeacher_Roleplay(cache_dir=data_path, mode=mode, **kwargs)
elif dataset_name == "orca-chat":
dataset = OrcaChat(cache_dir=data_path, **kwargs)
elif dataset_name == "dolphin-mix":
dataset = DolphinMix(cache_dir=data_path, **kwargs)
else:
raise ValueError(f"Unknown dataset {dataset_name}")

Expand Down
59 changes: 59 additions & 0 deletions model/model_training/custom_datasets/prompt_dialogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import requests
from datasets import load_dataset
from model_training.custom_datasets.formatting import DatasetEntrySft, Role, Utterance
Expand Down Expand Up @@ -193,3 +194,61 @@ def __getitem__(self, idx):
]

return DatasetEntrySft(conversation=conv_utt, system_message=instruction)


class DolphinMix(Dataset):
name = "dophin-mix"

def __init__(self, cache_dir, num_samples=100000, max_char_len=8000, seed=42):
self.dataset = load_dataset(
"ehartford/dolphin", data_files="flan5m-alpaca-uncensored.jsonl", cache_dir=cache_dir
)
self.dataset = self.dataset["train"].shuffle(seed).select(range(num_samples))
self.max_char_len = max_char_len
instructions = set([item["instruction"] for item in self.dataset])

self.conversations = []
for inst in instructions:
data_sample = self.dataset.filter(lambda example: example["instruction"] == inst)
available_indices = np.arange(0, len(data_sample)).tolist()
removed_indices = []
for idx in available_indices:
conversation_len = len(inst)
if idx not in removed_indices and conversation_len < self.max_char_len:
conversation = {"conversation": []}
conversation["instruction"] = inst
input, output = [data_sample[idx][key] for key in ("input", "output")]
conversation["conversation"].append({"input": input, "output": output})
conversation_len += len(input) + len(output)
removed_indices.append(idx)
while conversation_len < self.max_char_len:
indices_to_pick = np.setdiff1d(available_indices, removed_indices)
if len(indices_to_pick) > 0:
idx = np.random.choice(indices_to_pick, size=1)[0]
input, output = [data_sample[int(idx)][key] for key in ("input", "output")]
conversation["conversation"].append({"input": input, "output": output})
conversation_len += len(input) + len(output)
removed_indices.append(idx)
else:
break

self.conversations.append(conversation)

def __len__(self):
return len(self.conversations)

def __getitem__(self, idx):
conversation, instruction = [self.conversations[idx][key] for key in ("conversation", "instruction")]
conversation = [(item["input"], item["output"]) for item in conversation]
conversation = list(sum(conversation, ()))
conv_utt: list[Utterance] = [
(
Utterance(
text=conv,
role=Role.prompter if i % 2 == 0 else Role.assistant,
)
)
for i, conv in enumerate(conversation)
]

return DatasetEntrySft(conversation=conv_utt, system_message=instruction)

0 comments on commit c2f444d

Please sign in to comment.