🐞 Bug: vram estimate issue with gemma2 Ollama model #117
Comments
While I'm passively offering to look at code in exchange for some orientation -- I'd love it if there was a CLI flag for this command that let you specify the top context length to search for as a power of 2, for example.
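For illustration only, a minimal sketch of how such a power-of-two flag could be interpreted (the flag name `--top-ctx`, the default, and the argparse framing are assumptions made for this sketch; they are not part of gollama's actual Go CLI):

```python
#!/usr/bin/env python3
# Hypothetical sketch of the suggested flag: take the top context length as a
# power of 2 and derive the list of context sizes to estimate. The flag name
# and defaults are illustrative assumptions, not gollama's real interface.
import argparse

parser = argparse.ArgumentParser(description="Sketch of a power-of-two top-context flag")
parser.add_argument("--top-ctx", type=int, default=17,
                    help="largest context to estimate, as a power of 2 (e.g. 17 -> 131072)")
args = parser.parse_args()

# Estimate at 2048, 4096, ... up to 2**top_ctx
context_sizes = [2 ** k for k in range(11, args.top_ctx + 1)]
print(f"context lengths to estimate: {context_sizes}")
```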
That’s a really good idea! Re: NaNs: that’s “cool” 😂, will look at this as well
hey @sammcj I recognize you I think from the Ollama discord :) thank you for the help!
It's not just Gemma - it's a lot of (all?) models! Not quite sure how that happened but I'm looking into it.
There was a bit to it, but got there 😅 https://github.com/sammcj/gollama/pull/118/files
awesome :) thank you. Here's a script I've written now to make use of it, to simplify the most frequent task I have with this vram command (this is written for a Mac, and would be complicated to add to gollama across various architectures):

```python
#!/usr/bin/env python3
import argparse
import subprocess
import re
import sys
import logging


def run_command(command):
    """Run a shell command and return its output."""
    try:
        result = subprocess.run(command, stdout=subprocess.PIPE,
                                text=True, check=True)
        return result.stdout.strip()
    except subprocess.CalledProcessError as e:
        logging.error(f"Error running command {' '.join(command)}: {e}")
        sys.exit(1)


def extract_quant(model_name):
    """Extract the quant value from gollama -l output for the given model."""
    output = run_command(["gollama", "-l"])
    pattern = re.compile(rf"{re.escape(model_name)}\s+\S+\s+(\S+)")
    match = pattern.search(output)
    if match:
        quant = match.group(1)
        logging.info(f"Quant value for {model_name}: {quant}")
        return quant
    else:
        logging.error(f"Could not find quant value for model '{model_name}'")
        sys.exit(1)


def extract_vram_nth(model_name):
    """Extract the context length (vram_nth) from ollama show <model> output."""
    output = run_command(["ollama", "show", model_name])
    pattern = re.compile(r"context length\s+(\d+)")
    match = pattern.search(output)
    if match:
        vram_nth = match.group(1)
        logging.info(f"VRAM nth for {model_name}: {vram_nth}")
        return vram_nth
    else:
        logging.error(f"Could not extract context length for model '{model_name}'")
        sys.exit(1)


def run_vram_estimation(model_name, vram_nth, fits_limit):
    """Run gollama -vram estimation and return the output."""
    logging.info(f"Running gollama -vram for {model_name} with fits={fits_limit} GB")
    output = run_command([
        "gollama", "-vram", model_name, "--fits", str(fits_limit),
        "--vram-to-nth", vram_nth
    ])
    return output


def find_largest_below_fits(vram_output, quant, fits):
    """Find the largest A, B, and C values below the fits limit, along with their column names."""
    lines = vram_output.splitlines()
    header = lines[0]
    separator = lines[1]
    labels = lines[2]
    rows = lines[3:]
    logging.info("VRAM output header, labels, and rows gathered")

    # Find the quant row
    quant_row = None
    for row in rows:
        if quant in row:
            quant_row = row
            break
    if not quant_row:
        logging.error(f"Could not find matching row for quant '{quant}'")
        sys.exit(1)
    logging.info(f"Quant row: {quant_row}")

    # Extract column names from labels
    column_names = [col.strip() for col in labels.split('|')[3:]]
    columns = quant_row.split("|")[3:]

    max_A, max_A_ctx = None, None
    max_B, max_B_ctx = None, None
    max_C, max_C_ctx = None, None
    for idx, col in enumerate(columns):
        col = col.strip()
        match = re.match(r'([\d\.]+)(?:\(([\d\.]+),\s*([\d\.]+)\))?', col)
        if match:
            A_val = float(match.group(1))
            B_val = float(match.group(2) or 0)
            C_val = float(match.group(3) or 0)
            ctx_size = column_names[idx + 1] if idx + 1 < len(column_names) else "Unknown"
            if A_val <= fits and (max_A is None or A_val >= max_A):
                max_A = A_val
                max_A_ctx = ctx_size
            if B_val <= fits and (max_B is None or B_val >= max_B):
                max_B = B_val
                max_B_ctx = ctx_size
            if C_val <= fits and (max_C is None or C_val >= max_C):
                max_C = C_val
                max_C_ctx = ctx_size

    logging.info(f"Max A: {max_A} at {max_A_ctx}")
    logging.info(f"Max B: {max_B} at {max_B_ctx}")
    logging.info(f"Max C: {max_C} at {max_C_ctx}")

    if max_A is not None or max_B is not None or max_C is not None:
        final_output = f"{max_A_ctx}@{max_A} ({max_B_ctx}@{max_B}, {max_C_ctx}@{max_C})"
        logging.info(f"Final Output: {final_output}")
        return header, labels, separator, final_output
    else:
        logging.error(f"No values found below the fits limit of {fits} GB")
        sys.exit(1)


def get_default_fits():
    """Get the default fits value from sysctl iogpu.wired_limit_mb."""
    output = run_command(["sysctl", "iogpu.wired_limit_mb"])
    match = re.search(r"(\d+)", output)
    if match:
        wired_limit_mb = int(match.group(1))
        fits = wired_limit_mb / 1024  # Convert MB to GB
        logging.info(f"Default fits value from sysctl: {fits} GB")
        return fits
    else:
        logging.error("Could not retrieve iogpu.wired_limit_mb from sysctl")
        sys.exit(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Estimate VRAM usage for a given model.")
    parser.add_argument("model_name", help="Name of the model")
    parser.add_argument(
        "--fits", type=float, default=None,
        help="Fits limit in GB (default: iogpu.wired_limit_mb / 1024)"
    )
    parser.add_argument("--verbose", "-v",
                        action="store_true", help="Verbose output")
    args = parser.parse_args()

    # Set up logging with 'VERBOSE' instead of 'DEBUG'
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.ERROR,
                        format='VERBOSE: %(message)s' if args.verbose else '%(message)s',
                        stream=sys.stderr)

    model_name = args.model_name
    fits_limit = args.fits or get_default_fits()

    quant = extract_quant(model_name)
    vram_nth = extract_vram_nth(model_name)
    vram_output = run_vram_estimation(model_name, vram_nth, fits_limit)
    header, labels, separator, largest_col = \
        find_largest_below_fits(vram_output, quant, fits_limit)

    print(f"Using fits value: {fits_limit:.2f} GB")
    print(largest_col)
```

sample output:
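For reference, assuming the script above were saved as, say, `vram_fit.py` (the filename is just an assumption here), the common invocation would look like `./vram_fit.py qwen2.5:32b-instruct-q5_K_M -v`, with `--fits` left to default to `iogpu.wired_limit_mb / 1024` from sysctl, or set explicitly with something like `--fits 48`.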
much, much better than my old script -- your vram thing is what drew me to gollama:

```python
def estimate_model_memory(context_length, param_count, quant_bits, n_heads=16, embedding_length=2048):
    """
    Estimate total memory requirements for a transformer model.

    Args:
        context_length (int): Maximum sequence length the model can handle.
        param_count (int): Total number of parameters in the model.
        quant_bits (float): Average number of bits used for quantization.
        n_heads (int): Number of attention heads (default 16).
        embedding_length (int): Dimension of the model (default 2048).

    Returns:
        float: Total memory estimate in bytes.
    """
    # Estimate model size in bytes
    model_size_bytes = param_count * quant_bits / 8
    # KV cache size and activation memory
    kv_cache_size = 2 * context_length * embedding_length * n_heads * quant_bits / 8
    activation_memory = 4 * context_length * embedding_length
    # Total memory estimate
    return model_size_bytes + kv_cache_size + activation_memory


# for quantization bits see https://github.com/ggerganov/llama.cpp/blob/master/examples/quantize/README.md

# Example usage
# -- gemma2:27b-instruct-q6_K
# context_length = 8192
# param_count = 272e8
# quant_bits = 6.56  # q6_K quantization
# n_heads = 16
# embedding_length = 4608
# total_memory = estimate_model_memory(context_length, param_count, quant_bits, n_heads, embedding_length)
# print(f"Total Estimated Memory: {total_memory / (1024 ** 2):.2f} MB")
# Total Estimated Memory: 22359.39 MB

# -- mistral-nemo:12b-instruct-2407-q8_0
# context_length = 65536  # 1.024e+06
# param_count = 122e8
# quant_bits = 8  # q8_0 quantization
# n_heads = 32
# embedding_length = 5120
# total_memory = estimate_model_memory(context_length, param_count, quant_bits, n_heads, embedding_length)
# print(f"Total Estimated Memory: {total_memory / (1024 ** 2):.2f} MB")
# Total Estimated Memory: 351634.83 MB
# @ 65536 Total Estimated Memory: 33394.83 MB

# -- llama3.1:70b-instruct-q3_K_M
# context_length = 6144  # 131072
# param_count = 706e8
# quant_bits = 3.89
# n_heads = 64
# embedding_length = 6144  # 8192
# total_memory = estimate_model_memory(context_length, param_count, quant_bits, n_heads, embedding_length)
# print(f"Total Estimated Memory: {total_memory / (1024 ** 2):.2f} MB")
# Total Estimated Memory: 100568.68 MB
# @ 8192 Total Estimated Memory: 36978.28 MB
# @ 6144/6144 Total Estimated Memory: 35123.56 MB

# -- llama3.1:8b-instruct-q8_0
# context_length = 65536  # 32768, 131072
# param_count = 8e9
# quant_bits = 8
# n_heads = 32
# embedding_length = 8192
# total_memory = estimate_model_memory(context_length, param_count, quant_bits, n_heads, embedding_length)
# print(f"Total Estimated Memory: {total_memory / (1024 ** 2):.2f} MB")
# Total Estimated Memory: 77261.39 MB
# @ 32768 Total Estimated Memory: 25037.39 MB

# -- qwen2.5:32b-instruct-q5_K_M
context_length = 32768  # 131072
param_count = 32e9
quant_bits = 5.67
n_heads = 40
embedding_length = 5120
total_memory = estimate_model_memory(context_length, param_count, quant_bits, n_heads, embedding_length)
print(f"Total Estimated Memory: {total_memory / (1024 ** 2):.2f} MB")
# Total Estimated Memory: 60477.33 MB
# @ 32768 Total Estimated Memory: 31341.33 MB
```
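For what it's worth, the estimate returned by `estimate_model_memory` is the sum of three terms (quantized weights, the KV cache term, and a rough activation buffer). Below is a minimal sketch that prints each term separately for the gemma2:27b-instruct-q6_K numbers from the comments above; it only restates the formula in that function and is not how gollama itself builds its table:

```python
# Break the estimate_model_memory formula above into its three terms for
# gemma2:27b-instruct-q6_K (values taken from the commented example).
# This restates the formula from the comment above; it is not gollama's own method.
MB = 1024 ** 2

context_length = 8192
param_count = 272e8
quant_bits = 6.56
n_heads = 16
embedding_length = 4608

weights = param_count * quant_bits / 8                                        # quantized model weights
kv_cache = 2 * context_length * embedding_length * n_heads * quant_bits / 8  # KV cache term as defined above
activations = 4 * context_length * embedding_length                          # rough activation buffer

print(f"weights:     {weights / MB:.2f} MB")      # ~21270.75 MB
print(f"kv cache:    {kv_cache / MB:.2f} MB")     # ~944.64 MB
print(f"activations: {activations / MB:.2f} MB")  # 144.00 MB
print(f"total:       {(weights + kv_cache + activations) / MB:.2f} MB")  # 22359.39 MB
```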
I'm actually not clear about the three numbers though 🥹
Description
For some reason, of all my models, Gemma2 doesn't get good VRAM estimates.
Environment
Apple M3 Max, macOS Sequoia 15.0
Installed from releases (macOS)
go version go1.22.3 darwin/arm64
Can you contribute?
I'm not very familiar with `go`, but if I get a good sense of what to look at I would be happy to help.