import copy
import random
import re

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
def set_random_seed(seed: int = 42):
"""
Set the random seed for reproducibility across Python, NumPy, and PyTorch.
Parameters:
seed (int): The seed value to use.
"""
# Set the seed for Python's built-in random module
random.seed(seed)
# Set the seed for NumPy
np.random.seed(seed)
# Set the seed for PyTorch
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
# Ensure deterministic behavior in cuDNN (may impact performance)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
set_random_seed(42)
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
def prepare_dataset(split="train"):
"""Load and prepare the GSM8K dataset for training with string prompts."""
data = load_dataset('openai/gsm8k', 'main')[split]
formatted_data = []
for example in data:
# Convert list of messages to a single string prompt.
prompt_str = build_prompt([
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["question"]}
])
formatted_example = {
"prompt": prompt_str, # Now a string rather than a list.
"answer": extract_answer_from_dataset(example["answer"])
}
formatted_data.append(formatted_example)
return formatted_data
def build_prompt(messages):
"""
Build a single prompt string from a list of messages.
Each message is expected to be a dictionary with 'role' and 'content' keys.
This function concatenates all message contents, preserving the training format.
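Example (illustrative; the messages below are hypothetical):
>>> build_prompt([{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}]).splitlines()
['Be brief.', 'Hi']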
"""
return "\n".join([msg["content"].strip() for msg in messages])
def extract_answer_from_model_output(text):
"""
Extracts the value from the last <answer> tag in the text.
Returns None if no valid answer is found.
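Example (illustrative; the tagged text is a hypothetical model output):
>>> extract_answer_from_model_output("<reasoning>6*7</reasoning><answer> 42 </answer>")
'42'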
"""
# Split on <answer> and take everything after the last occurrence
parts = text.split("<answer>")
if len(parts) < 2: # No <answer> tag found
return None
last_part = parts[-1]
# Extract content up to </answer>
if "</answer>" not in last_part:
return None
answer = last_part.split("</answer>")[0].strip()
return None if answer == "..." else answer
def extract_answer_from_dataset(text):
"""
Extracts the answer from the dataset.
The dataset separates the answer using the '####' delimiter.
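Example (illustrative; a hypothetical GSM8K-style answer string):
>>> extract_answer_from_dataset("She sold 48 clips in total. #### 48")
'48'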
"""
if "####" not in text:
return None
return text.split("####")[1].strip()
def _extract_last_number(text):
"""
Extracts the last number from text if it's properly separated.
Args:
text (str): The text to extract a number from.
Returns:
float or None: The extracted number as a float, or None if no valid number is found.
Explanation:
1. First removes $ and % signs from the text.
2. Uses regex to find numbers that are:
- Preceded by space, equals sign, or start of string
- Followed by end of string or space
3. Returns the trailing number as a float, or None if no match is found.
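Example (illustrative):
>>> _extract_last_number("The total is $18")
18.0
>>> _extract_last_number("18 apples remain") is None  # number is not trailing
True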
"""
# Remove $ and % signs
text = text.replace('$', '').replace('%', '')
# Look for a number that is:
# - preceded by start of string, whitespace, or '=' (via the (?:^|\s|=) group)
# - at the end of the string (optionally followed by whitespace)
pattern = r'(?:^|\s|=)\s*(-?\d*\.?\d+)\s*$'
match = re.search(pattern, text)
return float(match.group(1)) if match else None
def _extract_single_number(text):
"""
Extracts a single number from text if exactly one exists.
Args:
text (str): The text to extract a number from.
Returns:
float or None: The extracted number as a float if exactly one number exists,
otherwise None.
Explanation:
1. Uses regex to find all numbers in the text.
2. Returns the first number as a float if exactly one number is found.
3. Returns None if zero or multiple numbers are found.
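Example (illustrative):
>>> _extract_single_number("The answer is 42")
42.0
>>> _extract_single_number("between 3 and 4") is None  # more than one number
True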
"""
numbers = re.findall(r'-?\d*\.?\d+', text)
return float(numbers[0]) if len(numbers) == 1 else None
def evaluate_model(model, tokenizer, eval_examples, device):
"""
Evaluates the model on a set of examples and prints detailed results.
Args:
model: The language model to evaluate.
tokenizer: The tokenizer for encoding inputs and decoding outputs.
eval_examples (list): List of evaluation examples, each containing "prompt" and "answer".
device: The device (CPU or GPU) to run evaluation on.
Returns:
float: The accuracy percentage (correct predictions / total examples * 100).
Explanation:
1. Sets the model to evaluation mode.
2. For each example in the evaluation set:
- Encodes the prompt and generates a response using the model.
- Extracts the predicted answer from the generated response.
- Compares the predicted answer with the expected answer using multiple methods:
a. Exact string matching
b. Single number extraction and comparison
c. Last number extraction and comparison
- Prints detailed information about each example.
3. Calculates and returns the overall accuracy.
4. Returns the model to training mode.
"""
model.eval()
correct = 0
total = len(eval_examples)
print("\n" + "="*50)
print("EVALUATION ON", total, "EXAMPLES")
print("="*50)
for example in eval_examples:
# Build the full prompt using the same method as training.
full_prompt = example["prompt"]
expected = example["answer"]
# Tokenize the full prompt (with attention mask) and sample a response from the model.
inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
outputs = model.generate(
**inputs,
max_new_tokens=512,
do_sample=True,  # required for temperature to take effect
temperature=0.7,
num_return_sequences=1,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract the predicted answer from the model output.
try:
predicted = extract_answer_from_model_output(response)
# Check correctness in multiple ways
if predicted == expected: # First try exact match
is_correct = True
else:
# Try single number
pred_num = _extract_single_number(str(predicted))
exp_num = _extract_single_number(str(expected))
if pred_num is not None and exp_num is not None and pred_num == exp_num:
is_correct = True
else:
# Try last number
pred_num = _extract_last_number(str(predicted))
exp_num = _extract_last_number(str(expected))
is_correct = (pred_num is not None and exp_num is not None and
pred_num == exp_num)
if is_correct:
correct += 1
# Print details of the evaluation.
print("\nPrompt:")
print(full_prompt)
print("\nExpected Answer:")
print(expected)
print("\nExtracted Answer:")
print(predicted)
print("\nFull Generated Response:")
print(response)
print("\nCorrect:", "✓" if is_correct else "✗")
print("-"*50)
except Exception as e:
print("\nFailed to parse model output for prompt:")
print(full_prompt)
print("Error:", e)
print("-"*50)
accuracy = (correct / total) * 100
print(f"\nAccuracy: {accuracy:.2f}% ({correct}/{total})")
print("="*50)
model.train()
return accuracy
def correctness_reward(prompts, completions, answer, **kwargs):
"""
Assigns a reward based on the correctness of the model's answer.
Args:
prompts (list[str]): List of prompt texts.
completions (list[list[dict]]): List of completion dictionaries.
answer (list[str]): List of expected answers.
**kwargs: Additional keyword arguments.
Returns:
list[float]: Reward scores based on answer correctness.
Explanation:
1. Extracts the text content from each completion.
2. Processes each response to extract the answer portion.
3. Compares extracted answers with expected answers using two methods:
- Exact string matching (2.0 points)
- Numeric equivalence check (1.5 points)
4. Returns a list of reward scores.
"""
# Extract the content from each completion's first element
responses = [completion[0]['content'] for completion in completions]
# Extract answers from model outputs
extracted = [extract_answer_from_model_output(r) for r in responses]
rewards = []
for r, a in zip(extracted, answer):
if r == a: # Exact match case
rewards.append(2.0)
else:
# Try numeric equivalence
r_num = _extract_single_number(str(r))
a_num = _extract_single_number(str(a))
if r_num is not None and a_num is not None and r_num == a_num:
rewards.append(1.5)
else:
rewards.append(0.0)
return rewards
def format_reward(completions, **kwargs):
"""
Assigns a reward for adhering to the desired XML format.
Args:
completions (list[list[dict]]): List of completion dictionaries.
**kwargs: Additional keyword arguments.
Returns:
list[float]: Reward scores based on format compliance.
Explanation:
1. Extracts the text content from each completion.
2. Assigns points based on the presence of required XML tags:
- 0.2 points for opening <reasoning> tag
- 0.2 points for closing </reasoning> tag
- 0.2 points for opening <answer> tag
- 0.2 points for closing </answer> tag
3. Returns a list of format compliance scores.
"""
# Extract the content from each completion's first element
responses = [completion[0]['content'] for completion in completions]
rewards = []
for response in responses:
score = 0.0
if "<reasoning>" in response: score += 0.2
if "</reasoning>" in response: score += 0.2
if "<answer>" in response: score += 0.2
if "</answer>" in response: score += 0.2
rewards.append(score)
return rewards
def combined_reward(prompts, completions, answer):
"""
Combines correctness and format rewards to provide a comprehensive evaluation.
Args:
prompts (list[str]): List of prompt texts.
completions (list[list[dict]]): List of completion dictionaries.
answer (list[str]): List of expected answers.
Returns:
list[float]: Combined rewards for each prompt-completion pair.
Explanation:
1. Calculates individual reward components:
- Correctness rewards (range: 0.0 to 2.0)
- Format rewards (range: 0.0 to 0.8)
2. Combines the rewards by adding them together.
3. Returns the combined scores with total range of 0.0 to 2.8.
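Example (illustrative; one hypothetical completion in the expected nested format):
>>> combined_reward(["prompt"], [[{"content": "<answer>7</answer>"}]], ["7"])
[2.4]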
"""
# Get individual rewards
correctness_scores = correctness_reward(prompts=prompts, completions=completions, answer=answer)
format_scores = format_reward(completions=completions)
# Combine rewards - correctness is weighted more heavily
combined_rewards = []
for c_score, f_score in zip(correctness_scores, format_scores):
# Correctness score range: 0.0 to 2.0
# Format score range: 0.0 to 0.8
# Total range: 0.0 to 2.8
combined_rewards.append(c_score + f_score)
return combined_rewards
def selective_log_softmax(logits, input_ids):
"""
Compute the log probabilities for the tokens specified in input_ids using a selective log-softmax.
Args:
logits (torch.Tensor): A tensor of shape (batch_size, seq_len, vocab_size) containing raw logits from the model.
input_ids (torch.Tensor): A tensor of shape (batch_size, seq_len) containing the token indices for which we want the log probabilities.
Returns:
torch.Tensor: A tensor of shape (batch_size, seq_len) where each element is the log probability
corresponding to the token in input_ids at that position.
Explanation:
1. F.log_softmax is applied along the vocabulary dimension (dim=-1) to convert logits into log probabilities.
2. The tensor input_ids is reshaped (via unsqueeze) to have an extra dimension so that we can use it as indices
in the log_probs tensor.
3. torch.gather collects the log probability at the index specified in input_ids for each position.
4. Finally, squeeze(-1) removes the extra dimension, returning a tensor with the same shape as input_ids.
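Example (illustrative shape check with uniform logits):
>>> logits = torch.zeros(2, 3, 10)  # batch=2, seq_len=3, vocab=10
>>> ids = torch.zeros(2, 3, dtype=torch.long)
>>> selective_log_softmax(logits, ids).shape
torch.Size([2, 3])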
"""
# Convert raw logits into log probabilities along the vocabulary axis.
log_probs = F.log_softmax(logits, dim=-1) # Shape: (batch_size, seq_len, vocab_size)
# Reshape input_ids from (batch_size, seq_len) to (batch_size, seq_len, 1) for gathering.
# Then, gather the log probability for each token in input_ids.
selected_log_probs = log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1))
# Remove the extra last dimension to get back to shape (batch_size, seq_len).
return selected_log_probs.squeeze(-1)
def compute_log_probabilities(model, input_ids, attention_mask, logits_to_keep):
"""
Compute per-token log probabilities for a subset of tokens (typically the completion tokens).
Args:
model: The language model to use.
input_ids (torch.Tensor): Tensor of shape (batch_size, total_seq_len) containing token ids
for both prompt and completion.
attention_mask (torch.Tensor): Tensor of shape (batch_size, total_seq_len) indicating which tokens are real (1) or padding (0).
logits_to_keep (int): Number of tokens (from the completion part) for which we need log probabilities.
Returns:
torch.Tensor: Log probabilities for the last `logits_to_keep` tokens of each sequence.
Explanation:
1. We call the model with logits_to_keep + 1 so that it returns one more logit than needed;
this keeps each logit aligned with the token it predicts, as is standard in next-token prediction.
2. We slice off the last logit along the sequence dimension because it does not correspond to any input token.
3. We then restrict both the input_ids and logits to the last logits_to_keep tokens, which should
correspond to the generated completion portion.
4. Finally, we use the selective_log_softmax to compute log probabilities only for those tokens.
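Note: passing `logits_to_keep` to the model's forward call assumes a recent
transformers release (older versions expose it as `num_logits_to_keep`); if
neither is supported, compute the full logits and slice them instead.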
"""
# Run the model forward pass and obtain logits.
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
logits_to_keep=logits_to_keep + 1 # Request one extra logit for proper alignment.
).logits # Shape: (batch_size, total_seq_len, vocab_size)
# Remove the last logit as it does not have a corresponding target token.
logits = logits[:, :-1, :] # New shape: (batch_size, total_seq_len - 1, vocab_size)
# Slice the input_ids to keep only the last logits_to_keep tokens.
# This corresponds to the generated completion tokens.
input_ids = input_ids[:, -logits_to_keep:] # Shape: (batch_size, logits_to_keep)
# Also slice the logits to keep only those corresponding to the completion tokens.
logits = logits[:, -logits_to_keep:, :] # Shape: (batch_size, logits_to_keep, vocab_size)
# Compute and return the log probabilities for the selected tokens.
return selective_log_softmax(logits, input_ids)
def create_completion_mask(completion_ids, eos_token_id):
"""
Create a binary mask for the generated completion tokens so that tokens after the first EOS are ignored.
Args:
completion_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) with generated token ids.
eos_token_id (int): The token id representing the end-of-sequence.
Returns:
torch.Tensor: A mask tensor of shape (batch_size, seq_len) with 1s for tokens up to and including the first EOS
and 0s for tokens following the first EOS.
Explanation:
1. First, a boolean mask (is_eos) is created indicating where in the sequence the EOS token appears.
2. An index tensor (eos_idx) is initialized, assuming that no EOS is found (defaulting to the sequence length).
3. For sequences where EOS exists, eos_idx is updated to the position (index) of the first EOS.
4. A sequence index tensor is created that contains indices for each position in the sequence.
5. The final mask is computed by comparing the sequence indices to eos_idx (after adding a dimension).
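Example (illustrative; eos_token_id=2 is hypothetical):
>>> ids = torch.tensor([[5, 2, 7, 7]])
>>> create_completion_mask(ids, eos_token_id=2).tolist()
[[1, 1, 0, 0]]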
"""
# Determine which positions in each sequence equal the EOS token.
is_eos = completion_ids == eos_token_id # Boolean tensor of shape (batch_size, seq_len)
# Initialize a tensor to store the index of the first EOS for each sequence.
# If no EOS is found, default to the full sequence length (is_eos.size(1)).
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)
# Identify sequences that contain at least one EOS.
mask_exists = is_eos.any(dim=1)
# For sequences with an EOS, update eos_idx to the index of the first occurrence.
eos_idx[mask_exists] = is_eos.int().argmax(dim=1)[mask_exists]
# Create a tensor of indices [0, 1, 2, ..., seq_len-1] and replicate it for each sequence in the batch.
sequence_indices = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1)
# Build the mask: positions with an index less than or equal to the first EOS index are marked as 1.
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
return completion_mask
def generate_completions(model, tokenizer, prompts, num_generations=4, max_completion_length=32):
"""
Generate multiple completions for each prompt and create corresponding attention masks.
Args:
model: The language model used for generation.
tokenizer: The tokenizer to process the prompts and decode the outputs.
prompts (list of str): List of input prompt strings.
num_generations (int): Number of completions to generate per prompt.
max_completion_length (int): Maximum number of new tokens to generate for the completion.
Returns:
tuple: Contains the following tensors:
- prompt_ids: (batch_size * num_generations, prompt_seq_len)
- prompt_mask: (batch_size * num_generations, prompt_seq_len)
- completion_ids: (batch_size * num_generations, completion_seq_len)
- completion_mask: (batch_size * num_generations, completion_seq_len)
Explanation:
1. The prompts are tokenized and padded (with padding added to the left).
2. Each prompt is repeated num_generations times so that multiple completions are generated per prompt.
3. The model.generate() function is called to generate new tokens.
4. The generated output contains the prompt followed by the completion; we remove the prompt part to get the completions.
5. A mask is created (via create_completion_mask) so that only tokens up to the first EOS are considered.
"""
device = next(model.parameters()).device
# Tokenize the prompts with left padding so that every prompt ends at the same
# position and generation starts right after the prompt for all sequences.
tokenizer.padding_side = "left"
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
prompt_ids = inputs["input_ids"].to(device) # Shape: (batch_size, prompt_seq_len)
prompt_mask = inputs["attention_mask"].to(device) # Shape: (batch_size, prompt_seq_len)
prompt_length = prompt_ids.size(1) # Save the prompt length to later separate prompt from completion.
# Repeat each prompt num_generations times.
prompt_ids = prompt_ids.repeat_interleave(num_generations, dim=0) # New shape: (batch_size*num_generations, prompt_seq_len)
prompt_mask = prompt_mask.repeat_interleave(num_generations, dim=0) # New shape: (batch_size*num_generations, prompt_seq_len)
# Generate new tokens for each prompt. The output includes the original prompt and the generated tokens.
outputs = model.generate(
prompt_ids,
attention_mask=prompt_mask,
max_new_tokens=max_completion_length,
do_sample=True,
temperature=1.0,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
# Remove the prompt portion from the generated output to isolate the completion tokens.
completion_ids = outputs[:, prompt_length:] # Shape: (batch_size*num_generations, completion_seq_len)
# Create a binary mask that ignores tokens beyond the first EOS token.
completion_mask = create_completion_mask(completion_ids, tokenizer.eos_token_id)
return prompt_ids, prompt_mask, completion_ids, completion_mask
def generate_rollout_data(model, ref_model, tokenizer, batch_samples, num_generations, max_completion_length):
"""
Generate rollouts and compute static log probabilities for both the old policy (current model)
and the reference model. Gradients are disabled so that these remain fixed.
Args:
model: The current model (policy) used to generate rollouts.
ref_model: The static reference model.
tokenizer: The tokenizer.
batch_samples: List of training samples.
num_generations: Number of completions to generate per prompt.
max_completion_length: Maximum completion length.
Returns:
A dictionary with rollout data including both old and reference log probabilities.
"""
tokenizer.padding_side = "left"
device = next(model.parameters()).device
# Extract prompts and answers.
prompts = [sample["prompt"] if isinstance(sample, dict) else sample[0] for sample in batch_samples]
answers = [sample["answer"] if isinstance(sample, dict) else sample[1] for sample in batch_samples]
# Generate completions and associated masks.
# We generate once, and then use the same completions to compute both sets of log probabilities.
with torch.no_grad():
prompt_ids, prompt_mask, completion_ids, completion_mask = generate_completions(
model, tokenizer, prompts, num_generations, max_completion_length
)
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1)
# Compute old_log_probs from the current model, with gradients disabled.
old_log_probs = compute_log_probabilities(model, input_ids, attention_mask, logits_to_keep)
# Compute ref_log_probs from the reference model, which remains static.
ref_log_probs = compute_log_probabilities(ref_model, input_ids, attention_mask, logits_to_keep)
formatted_completions = [
[{'content': tokenizer.decode(ids, skip_special_tokens=True)}]
for ids in completion_ids
]
repeated_prompts = [p for p in prompts for _ in range(num_generations)]
repeated_answers = [a for a in answers for _ in range(num_generations)]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"completion_mask": completion_mask,
"old_log_probs": old_log_probs, # Static log probs from the current model (old policy)
"ref_log_probs": ref_log_probs, # Static log probs from the reference model
"formatted_completions": formatted_completions,
"repeated_prompts": repeated_prompts,
"repeated_answers": repeated_answers,
"logits_to_keep": logits_to_keep,
"batch_size": len(prompts),
"num_generations": num_generations
}
def compute_group_relative_advantages(rewards, num_generations):
"""
Compute group-relative advantages for each prompt group.
Args:
rewards (torch.Tensor): Tensor of shape (batch_size * num_generations) containing rewards.
num_generations (int): Number of completions generated per prompt.
Returns:
torch.Tensor: Tensor of advantages computed relative to the group mean.
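Example (illustrative; two prompts with two generations each):
>>> adv = compute_group_relative_advantages(torch.tensor([1.0, 3.0, 0.0, 2.0]), 2)
>>> adv.shape  # one advantage per completion, with a trailing dim for broadcasting
torch.Size([4, 1])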
"""
# Reshape rewards to group by prompt
rewards_by_group = rewards.view(-1, num_generations)
# Compute mean and standard deviation for each prompt group
group_means = rewards_by_group.mean(dim=1)
group_stds = rewards_by_group.std(dim=1)
# Expand the means and stds to match the original flat rewards tensor shape
expanded_means = group_means.repeat_interleave(num_generations)
expanded_stds = group_stds.repeat_interleave(num_generations)
# Normalize rewards to get advantages
advantages = (rewards - expanded_means) / (expanded_stds + 1e-4)
return advantages.unsqueeze(1) # Add dimension for token-wise operations
def maximize_grpo_objective(model, ref_model, rollout_data, tokenizer, reward_function,
optimizer, beta, epsilon):
"""
Update the policy model by maximizing the GRPO objective.
Args:
model: The current policy model.
ref_model: The reference model.
rollout_data: Dictionary containing rollout data.
tokenizer: The tokenizer.
reward_function: Function to compute rewards.
optimizer: The optimizer.
beta (float): KL penalty coefficient.
epsilon (float): Clipping parameter.
Returns:
float: The loss value.
"""
# Extract data from rollout
input_ids = rollout_data["input_ids"]
attention_mask = rollout_data["attention_mask"]
completion_mask = rollout_data["completion_mask"]
old_log_probs = rollout_data["old_log_probs"]
ref_log_probs = rollout_data["ref_log_probs"]
logits_to_keep = rollout_data["logits_to_keep"]
# Compute current log probabilities
current_log_probs = compute_log_probabilities(model, input_ids, attention_mask, logits_to_keep)
# Compute policy ratio
ratio = torch.exp(current_log_probs - old_log_probs)
# Get rewards data
formatted_completions = rollout_data["formatted_completions"]
repeated_prompts = rollout_data["repeated_prompts"]
repeated_answers = rollout_data["repeated_answers"]
# Compute rewards
rewards = torch.tensor(
reward_function(prompts=repeated_prompts, completions=formatted_completions, answer=repeated_answers),
dtype=torch.float32,
device=next(model.parameters()).device
)
avg_reward = rewards.mean().item()
print(f"Average Reward: {avg_reward:.4f}")
# Compute advantages using group-relative normalization
num_generations = rollout_data["num_generations"]
advantages = compute_group_relative_advantages(rewards, num_generations)
# Compute surrogate loss with clipping
surrogate1 = ratio * advantages
surrogate2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
surrogate_loss = torch.min(surrogate1, surrogate2)
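# The penalty below is the unbiased "k3" estimator of KL(current || ref)
# (Schulman, "Approximating KL Divergence"): with r = exp(ref_lp - cur_lp),
# k3 = r - log(r) - 1, which is always >= 0.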
# Compute KL divergence penalty
kl_div = torch.exp(ref_log_probs - current_log_probs) - (ref_log_probs - current_log_probs) - 1
# Combine losses
per_token_loss = surrogate_loss - beta * kl_div
loss = -((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# Optimization step
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
optimizer.step()
return loss.item()
def train_with_grpo(model, tokenizer, train_data, num_iterations=1,
steps_per_iteration=500, batch_size=4, num_generations=4,
max_completion_length=128, beta=0.1, learning_rate=5e-6,
mu=3, epsilon=0.2, reward_function=combined_reward):
"""
Iterative Group Relative Policy Optimization algorithm.
Args:
model: The initial policy model to be fine-tuned.
tokenizer: The tokenizer used for encoding prompts and decoding completions.
train_data (list): List of training samples with "prompt" and "answer" fields.
num_iterations (int): Number of outer iterations; each one re-snapshots the reference model.
steps_per_iteration (int): Number of policy update steps per iteration.
batch_size (int): Number of prompt samples per batch.
num_generations (int): Number of completions to generate per prompt.
max_completion_length (int): Maximum token length for completions.
beta (float): KL-divergence penalty coefficient.
learning_rate (float): Learning rate for optimizer.
mu (int): Number of GRPO updates per batch of generations.
epsilon (float): Clipping parameter for surrogate objective.
reward_function: Function that evaluates completions and returns rewards.
Returns:
The fine-tuned policy model.
"""
# Initialize policy model
policy_model = model
device = next(policy_model.parameters()).device
# Outer loop for iterations with reward model updates
for iteration in range(1, num_iterations + 1):
print(f"\nStarting iteration {iteration}/{num_iterations}")
# Create reference model for KL constraint
reference_model = copy.deepcopy(policy_model)
reference_model.eval()
for param in reference_model.parameters():
param.requires_grad = False
reference_model = reference_model.to(device)
# Initialize optimizer
optimizer = torch.optim.Adam(policy_model.parameters(), lr=learning_rate)
policy_model.train()
# Inner loop for policy updates
for step in range(1, steps_per_iteration + 1):
# Sample batch of prompts
batch_samples = random.sample(train_data, batch_size)
# Set old policy for this step
with torch.no_grad():
# Generate completions and compute log probs
rollout_data = generate_rollout_data(
policy_model, reference_model, tokenizer,
batch_samples, num_generations, max_completion_length
)
# Multiple GRPO updates per batch of generations
for grpo_iter in range(1, mu + 1):
loss_value = maximize_grpo_objective(
policy_model, reference_model, rollout_data, tokenizer,
reward_function, optimizer, beta, epsilon
)
print(f"Iteration {iteration}/{num_iterations}, Step {step}/{steps_per_iteration}, "
f"GRPO update {grpo_iter}/{mu}, Loss: {loss_value:.4f}")
# Optional: Update reward model here if using reward model training
# This is not implemented in the original code but present in the pseudocode
print(f"Completed iteration {iteration}. Reward model update would happen here.")
return policy_model
def optimize_model_memory(model):
"""Apply memory optimizations like proper gradient checkpointing setup"""
# Ensure model is in training mode
model.train()
# Disable caching for gradient checkpointing
model.config.use_cache = False
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Enable input gradients properly
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
return model
def main():
"""
Main function to run the complete training and evaluation pipeline.
The process consists of:
1. Loading the pre-trained model and tokenizer.
2. Evaluating the initial model performance (before any finetuning).
3. Performing reinforcement learning (GRPO) finetuning.
4. Evaluating the final model after GRPO finetuning.
5. Saving the finetuned model and tokenizer.
Note: Helper functions such as prepare_dataset, evaluate_model, and combined_reward
are defined earlier in this module.
"""
# Determine the device (GPU if available, otherwise CPU) from the model's parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Define the model name and output directory.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # Note: the 0.5B model usually cannot produce
# the <reasoning> and <answer> tags on its own, so a few rounds of SFT to teach it
# this format are recommended before RL.
output_dir = "math_solver_model"
# Load the pre-trained causal language model.
# - torch_dtype=bfloat16 reduces memory use on supported hardware.
# - attn_implementation can select an optimized attention kernel (commented out here).
# - device_map=None loads the whole model onto one device; it is moved to `device` below.
print("Downloading model...")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
# attn_implementation="flash_attention_2",  # enable if flash-attn is installed
device_map=None
)
print("Downloaded model")
# Move the model to the determined device.
model = model.to(device)
# Load the tokenizer corresponding to the model.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# Set the pad token to be the same as the end-of-sequence token.
tokenizer.pad_token = tokenizer.eos_token
# Update the model configuration with the correct token IDs.
model.config.pad_token_id = tokenizer.eos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
# -------------------------------
# Step 0: INITIAL EVALUATION
# -------------------------------
# Load the complete training dataset using the prepare_dataset helper defined above.
all_data = prepare_dataset("train")
# Randomize the order of examples.
random.shuffle(all_data)
# Hold out a small subset for evaluation (a single example here; increase this
# for a more reliable accuracy estimate).
num_eval_examples = 1
eval_data = all_data[:num_eval_examples]
# Evaluate the initial performance of the model before any finetuning.
print("\nInitial model evaluation before GRPO:")
pre_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device)
print(f"Pre-GRPO Accuracy: {pre_grpo_accuracy:.2f}%")
model = optimize_model_memory(model)
# -------------------------------
# Step 1: RL FINETUNING (GRPO)
# -------------------------------
print("\nStarting RL finetuning using GRPO...")
# Use the remaining examples (beyond the evaluation subset) for RL finetuning.
train_data = all_data[num_eval_examples:]
# Define RL training configuration.
training_config = {
'num_iterations' : 1,
'steps_per_iteration': 500, # Total number of RL training steps.
'batch_size': 4, # Number of samples per training step.
'num_generations': 16, # Number of completions generated per prompt.
'max_completion_length': 500, # Maximum token length for each generated completion.
'beta': 0.04, # KL divergence penalty coefficient.
'learning_rate': 5e-6, # Learning rate for RL fine-tuning.
'mu': 1,
'epsilon': 0.1,
'reward_function': combined_reward
}
# Fine-tune the model using GRPO RL training.
model = train_with_grpo(
model=model,
tokenizer=tokenizer,
train_data=train_data,
**training_config
)
# -------------------------------
# Step 2: FINAL EVALUATION & SAVING
# -------------------------------
print("\nFinal model evaluation after GRPO RL finetuning:")
# Evaluate the final model performance using the evaluation dataset.
post_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device)
print(f"Post-GRPO Accuracy: {post_grpo_accuracy:.2f}%")
print(f"Total Improvement: {post_grpo_accuracy - pre_grpo_accuracy:.2f}%")
print("\nSaving GRPO finetuned model...")
# Save the final finetuned model and tokenizer to disk.
model.save_pretrained("grpo_finetuned_model")
tokenizer.save_pretrained("grpo_finetuned_model")
if __name__ == "__main__":
main()