import torch
import triton
import matplotlib.pyplot as plt
def measure_matmul_flops(M: int, N: int, K: int, dtype=torch.bfloat16,
                         warmup: int = 25, rep: int = 100) -> float:
    """Benchmark an (M x K) @ (K x N) matmul and return achieved TFLOP/s.

    `warmup` and `rep` are passed through to triton.testing.do_bench,
    which interprets them as warmup time and repetition time in ms.
    """
    device = torch.cuda.current_device()

    # Create random matrices on the GPU
    A = torch.randn(M, K, dtype=dtype, device=device)
    B = torch.randn(K, N, dtype=dtype, device=device)

    # Define the operation to benchmark
    def matmul_op():
        return torch.matmul(A, B)

    # Measure average kernel time (in ms) with Triton's do_bench,
    # which handles warmup, repetition, and CUDA synchronization
    time_ms = triton.testing.do_bench(matmul_op, warmup=warmup, rep=rep)

    # A matmul performs 2 * M * N * K floating-point operations:
    # one multiply and one add per output element per K step
    total_ops = 2 * M * N * K
    flops_per_second = total_ops / (time_ms * 1e-3)  # ms -> s
    return flops_per_second / 1e12  # FLOP/s -> TFLOP/s
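
# Hedged sketch (the names below are illustrative, not part of the original
# script): to turn raw TFLOP/s into a utilization figure, divide by the GPU's
# peak dense bf16 throughput. PEAK_BF16_TFLOPS is a placeholder value for an
# H100 SXM (~989 TFLOP/s dense bf16); substitute the number from your
# accelerator's datasheet.
PEAK_BF16_TFLOPS = 989.0

def matmul_efficiency(M: int, N: int, K: int) -> float:
    # Fraction of peak bf16 throughput achieved by one matmul shape
    return measure_matmul_flops(M, N, K) / PEAK_BF16_TFLOPS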

# Sweep square matrices and report achieved throughput
matrix_sizes = [512, 1024, 2048, 4096, 8192, 16384]
for size in matrix_sizes:
    print(f"Testing {size}x{size} matrices...", end=" ")
    tflops = measure_matmul_flops(size, size, size)
    print(f"{tflops:.2f} TFLOP/s")