Cuda graph #67
@@ -280,9 +277,7 @@ def get_args():
parser.add_argument("-o", "--output-len", type=int, default=32, help="length of output sequence")
parser.add_argument("-n", "--num-requests", type=int, default=1000, help="number of requests to generate")
parser.add_argument("-p", "--profile-dir", type=str, default=None, help="directory to store torch profiler output")
parser.add_argument("-c", "--cuda-graph", action="store_true", default=False, help="enable cuda graph for all modules")
Removing this parameter will cause an AttributeError at model_config.enable_cuda_graph_attn = args.cuda_graph or args.cuda_graph_attn, because that line still reads args.cuda_graph.
DisagMoE/benchmark/benchmark_serving.py, line 99 in b4713a7:
model_config.enable_cuda_graph_attn = args.cuda_graph or args.cuda_graph_attn
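To illustrate the concern, here is a minimal sketch of the failure and one possible guard. It is not the project's code: build_parser and resolve_attn_flag are hypothetical helpers, and the --cuda-graph-attn flag spelling is an assumption (the diff only shows the attribute args.cuda_graph_attn). The point is that argparse only creates attributes for registered arguments, so dropping -c/--cuda-graph leaves args.cuda_graph undefined.

```python
import argparse


def build_parser(with_cuda_graph: bool) -> argparse.ArgumentParser:
    """Hypothetical stand-in for get_args(); only the flags relevant here."""
    parser = argparse.ArgumentParser()
    # Assumed flag spelling; the source only shows the attribute name.
    parser.add_argument("--cuda-graph-attn", action="store_true", default=False)
    if with_cuda_graph:
        parser.add_argument("-c", "--cuda-graph", action="store_true", default=False)
    return parser


def resolve_attn_flag(args: argparse.Namespace) -> bool:
    """Tolerant variant of line 99: a missing cuda_graph attribute counts as False."""
    return getattr(args, "cuda_graph", False) or args.cuda_graph_attn


# Parser without -c/--cuda-graph, as it would look after the removal.
args = build_parser(with_cuda_graph=False).parse_args(["--cuda-graph-attn"])
try:
    _ = args.cuda_graph or args.cuda_graph_attn  # same expression as line 99
    failed = False
except AttributeError:
    failed = True
```

The getattr guard is only one option; updating line 99 to stop reading args.cuda_graph, or keeping the flag, would avoid the error as well.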
Refactor cuda graph execution