From bc218ad187eafdb22e251e51d8c124ea52a72861 Mon Sep 17 00:00:00 2001
From: Shahules786 <shahules786@gmail.com>
Date: Tue, 25 Jul 2023 19:02:53 +0000
Subject: [PATCH 1/2] added dolphin mix

---
 .../custom_datasets/__init__.py               |  4 +-
 .../custom_datasets/prompt_dialogue.py        | 59 +++++++++++++++++++
 2 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/model/model_training/custom_datasets/__init__.py b/model/model_training/custom_datasets/__init__.py
index 740fae8563..3043a5db12 100644
--- a/model/model_training/custom_datasets/__init__.py
+++ b/model/model_training/custom_datasets/__init__.py
@@ -8,7 +8,7 @@
 from model_training.custom_datasets.instruction import INSTRUCTION_DATASETS, InstructionDataset
 from model_training.custom_datasets.oasst_dataset import load_oasst_export
 from model_training.custom_datasets.pretrain_datasets import FanFics, RedPajama
-from model_training.custom_datasets.prompt_dialogue import Gpt4All, OrcaChat, load_oig_file
+from model_training.custom_datasets.prompt_dialogue import DolphinMix, Gpt4All, OrcaChat, load_oig_file
 from model_training.custom_datasets.qa_datasets import (
     SODA,
     AlpacaGpt4,
@@ -176,6 +176,8 @@ def get_one_dataset(
         dataset = GPTeacher_Roleplay(cache_dir=data_path, mode=mode, **kwargs)
     elif dataset_name == "orca-chat":
         dataset = OrcaChat(cache_dir=data_path, **kwargs)
+    elif dataset_name == "dolphin-mix":
+        dataset = DolphinMix(cache_dir=data_path, **kwargs)
     else:
         raise ValueError(f"Unknown dataset {dataset_name}")
 
diff --git a/model/model_training/custom_datasets/prompt_dialogue.py b/model/model_training/custom_datasets/prompt_dialogue.py
index 2dda25333b..2b959e2a44 100644
--- a/model/model_training/custom_datasets/prompt_dialogue.py
+++ b/model/model_training/custom_datasets/prompt_dialogue.py
@@ -4,6 +4,7 @@
 from pathlib import Path
 from typing import List, Optional, Union
 
+import numpy as np
 import requests
 from datasets import load_dataset
 from model_training.custom_datasets.formatting import DatasetEntrySft, Role, Utterance
@@ -193,3 +194,61 @@ def __getitem__(self, idx):
         ]
 
         return DatasetEntrySft(conversation=conv_utt, system_message=instruction)
+
+
+class DolphinMix(Dataset):
+    name = "dolphin-mix"
+
+    def __init__(self, cache_dir, num_samples=100000, max_char_len=8000):
+        self.dataset = load_dataset(
+            "ehartford/dolphin", data_files="flan5m-alpaca-uncensored.jsonl", cache_dir=cache_dir
+        )
+        self.dataset = self.dataset["train"].shuffle(42).select(range(num_samples))
+        self.max_char_len = max_char_len
+        instructions = set([item["instruction"] for item in self.dataset])
+
+        self.conversations = []
+        for inst in instructions:
+            data_sample = self.dataset.filter(lambda example: example["instruction"] == inst)
+            available_indices = np.arange(0, len(data_sample)).tolist()
+            removed_indices = []
+            for idx in available_indices:
+                conversation_len = len(inst)
+                if idx not in removed_indices and conversation_len < self.max_char_len:
+                    conversation = {"conversation": []}
+                    conversation["instruction"] = inst
+                    input, output = [data_sample[idx][key] for key in ("input", "output")]
+                    conversation["conversation"].append({"input": input, "output": output})
+                    conversation_len += len(input) + len(output)
+                    removed_indices.append(idx)
+                    while conversation_len < self.max_char_len:
+                        indices_to_pick = np.setdiff1d(available_indices, removed_indices)
+                        if len(indices_to_pick) > 0:
+                            idx = np.random.choice(indices_to_pick, size=1)[0]
+                            input, output = [data_sample[int(idx)][key] for key in ("input", "output")]
+                            conversation["conversation"].append({"input": input, "output": output})
+                            conversation_len += len(input) + len(output)
+                            removed_indices.append(idx)
+                        else:
+                            break
+
+                    self.conversations.append(conversation)
+
+    def __len__(self):
+        return len(self.conversations)
+
+    def __getitem__(self, idx):
+        conversation, instruction = [self.conversations[idx][key] for key in ("conversation", "instruction")]
+        conversation = [(item["input"], item["output"]) for item in conversation]
+        conversation = list(sum(conversation, ()))
+        conv_utt: list[Utterance] = [
+            (
+                Utterance(
+                    text=conv,
+                    role=Role.prompter if i % 2 == 0 else Role.assistant,
+                )
+            )
+            for i, conv in enumerate(conversation)
+        ]
+
+        return DatasetEntrySft(conversation=conv_utt, system_message=instruction)

From b7a4fd2e5d880f6d09628772da61f97ccf6c1c12 Mon Sep 17 00:00:00 2001
From: Shahules786 <shahules786@gmail.com>
Date: Tue, 25 Jul 2023 19:10:56 +0000
Subject: [PATCH 2/2] config seed

---
 model/model_training/custom_datasets/prompt_dialogue.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/model/model_training/custom_datasets/prompt_dialogue.py b/model/model_training/custom_datasets/prompt_dialogue.py
index 2b959e2a44..476e82b979 100644
--- a/model/model_training/custom_datasets/prompt_dialogue.py
+++ b/model/model_training/custom_datasets/prompt_dialogue.py
@@ -199,11 +199,11 @@ def __getitem__(self, idx):
 class DolphinMix(Dataset):
     name = "dolphin-mix"
 
-    def __init__(self, cache_dir, num_samples=100000, max_char_len=8000):
+    def __init__(self, cache_dir, num_samples=100000, max_char_len=8000, seed=42):
         self.dataset = load_dataset(
             "ehartford/dolphin", data_files="flan5m-alpaca-uncensored.jsonl", cache_dir=cache_dir
         )
-        self.dataset = self.dataset["train"].shuffle(42).select(range(num_samples))
+        self.dataset = self.dataset["train"].shuffle(seed).select(range(num_samples))
         self.max_char_len = max_char_len
         instructions = set([item["instruction"] for item in self.dataset])