forked from PaulKMueller/llama_traffic
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathparse_waymo.py
56 lines (43 loc) · 2.01 KB
/
parse_waymo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os

import tensorflow as tf
def get_feature_description(example_proto):
    """Build a parsing spec covering every feature in ``example_proto``.

    The populated ``kind`` field of each feature selects the dtype of the
    ``tf.io.VarLenFeature`` stored under that feature's key:
    ``bytes_list`` -> ``tf.string``, ``float_list`` -> ``tf.float32``,
    ``int64_list`` -> ``tf.int64``.

    Args:
        example_proto: Object exposing ``features.feature`` as a mapping from
            feature name to a message that answers ``WhichOneof("kind")``.

    Returns:
        dict mapping each feature name to a ``tf.io.VarLenFeature``.

    Raises:
        ValueError: If a feature's kind is anything other than the three
            supported list types.
    """
    # Supported kind -> name of the matching dtype attribute on ``tf``.
    # The dtype is resolved only once a kind has been recognised.
    dtype_attr_by_kind = {
        "bytes_list": "string",
        "float_list": "float32",
        "int64_list": "int64",
    }

    spec = {}
    for name, message in example_proto.features.feature.items():
        which = message.WhichOneof("kind")
        if which not in dtype_attr_by_kind:
            raise ValueError(f"Unsupported feature type: {which}")
        spec[name] = tf.io.VarLenFeature(getattr(tf, dtype_attr_by_kind[which]))
    return spec
def parse_tfrecord(tfrecord_path):
    """Describe the features of the first record in a TFRecord file.

    The first serialized record is decoded as a ``tf.train.Example`` to
    discover its features, then parsed into tensors with the spec returned by
    ``get_feature_description``. Every feature is rendered as a short text
    block holding its name, value, shape and dtype.

    Args:
        tfrecord_path: Path handed to ``tf.data.TFRecordDataset``.

    Returns:
        str: The concatenated per-feature blocks; an empty string when the
        dataset yields no record.
    """
    blocks = []
    records = tf.data.TFRecordDataset(tfrecord_path)

    # A single record is enough to infer the structure.
    for serialized in records.take(1):
        proto = tf.train.Example()
        proto.ParseFromString(serialized.numpy())
        spec = get_feature_description(proto)

        # Parse the same record again, this time into tensors.
        tensors = tf.io.parse_single_example(serialized, spec)

        for name, tensor in tensors.items():
            # Sparse results are densified before conversion to NumPy.
            dense = (
                tf.sparse.to_dense(tensor)
                if isinstance(tensor, tf.SparseTensor)
                else tensor
            )
            value = dense.numpy()
            blocks.append(
                f"\nFeature: {name}"
                f"\n - Value: {value}"
                f"\n - Shape: {value.shape}"
                f"\n - DataType: {tensor.dtype}\n"
            )

    return "".join(blocks)
# Provide the path to your TFRecord file
tfrecord_path = "/mrtstorage/datasets/tmp/waymo_open_motion_v_1_2_0/uncompressed/tf_example/training/training_tfexample.tfrecord-00998-of-01000"
structure = parse_tfrecord(tfrecord_path)

# Make sure the output directory exists so the write cannot fail with
# FileNotFoundError after the record has already been parsed.
output_path = "data/structure_with_datatypes.txt"
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# parse_tfrecord already returns a str; the encoding is pinned so the file
# does not depend on the platform's default locale.
with open(output_path, "w", encoding="utf-8") as file:
    file.write(structure)