-
Notifications
You must be signed in to change notification settings - Fork 463
/
Copy pathinterleave_ffma.py
137 lines (115 loc) · 4.45 KB
/
interleave_ffma.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import argparse
import mmap
import os
import re
import subprocess
from torch.utils.cpp_extension import CUDA_HOME
def run_cuobjdump(file_path):
command = [f'{CUDA_HOME}/bin/cuobjdump', '-sass', file_path]
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
assert result.returncode == 0
return result.stdout
def extract_ffma(sass):
lines = sass.splitlines()
collected = []
current = []
arch_name, func_name = 'N/A', 'N/A'
skip_next_line = False
for line in lines:
if 'code for' in line:
arch_name = line.lstrip().lstrip('code for ').rstrip()
elif 'Function :' in line:
func_name = line.lstrip().lstrip('Function :').rstrip()
elif 'FFMA' in line:
current.append(line)
skip_next_line = True
elif skip_next_line:
current.append(line)
skip_next_line = False
else:
if len(current) >= 16:
assert len(current) % 2 == 0
collected.append((f'{arch_name}::{func_name}', current))
current = []
if os.getenv('DG_PRINT_REG_REUSE', None):
print(f'Found {len(collected)} FFMA segments')
return collected
def extract_hex_from_line(line):
match = re.search(r'/\*\s*(0x[0-9a-fA-F]+)\s*\*/', line)
assert match
return int(match.group(1), 16)
def validate(m, offset, le_bytes, num_lines):
assert len(le_bytes) == num_lines // 2
assert m[offset:offset + 16] == le_bytes[0]
for i in range(1, num_lines // 2):
if m[offset + i * 16:offset + i * 16 + 16] != le_bytes[i]:
return False
return True
def parse_registers(line):
line = re.sub(r'/\*.*?\*/', '', line)
line = line.replace(';', '')
tokens = line.strip().split(',')
registers = []
for token in tokens:
token = token.strip()
words = token.split()
for word in words:
if word.startswith('R'):
reg = word.split('.')[0]
registers.append(reg)
return registers
def modify_segment(m, name, ffma_lines):
num_lines = (len(ffma_lines) * 9 // 16) // 2 * 2
assert num_lines % 2 == 0
le_bytes, new_le_bytes = [], []
reused_list = []
dst_reg_set = set()
last_reused, last_dst_reg = False, ''
num_changed = 0
for i in range(num_lines // 2):
dst_reg = parse_registers(ffma_lines[i * 2])[-2]
low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1]
low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(high_line)
le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
reused = (high_hex & 0x0800000000000000) != 0
if reused:
is_first_occurred = dst_reg not in dst_reg_set
if is_first_occurred or (last_reused and dst_reg == last_dst_reg):
# Modify the `reuse` and `yield` bits
assert high_hex & 0x0800200000000000, f'{hex(high_hex)}'
high_hex ^= 0x0800200000000000
reused = False
num_changed += 1
else:
reused_list.append(i)
dst_reg_set.add(dst_reg)
new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
last_reused, last_dst_reg = reused, dst_reg
if os.getenv('DG_PRINT_REG_REUSE', None):
print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}')
# Find the offset
offsets = []
offset = m.find(le_bytes[0])
while offset != -1:
offsets.append(offset)
offset = m.find(le_bytes[0], offset + 1)
offsets = list(filter(lambda x: validate(m, x, le_bytes, num_lines), offsets))
# Replace with `new_le_bytes`
for offset in offsets:
for i in range(num_lines // 2):
m[offset + i * 16:offset + i * 16 + 16] = new_le_bytes[i]
def process(path):
if os.getenv('DG_PRINT_REG_REUSE', None):
print(f'Processing {path}')
output = run_cuobjdump(path)
segments = extract_ffma(output)
with open(path, 'r+b') as f:
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE)
for segment in segments:
modify_segment(mm, *segment)
mm.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse')
parser.add_argument('--so', help='Path to the SO file')
args = parser.parse_args()
process(args.so)