Dunfey · Hotel WWDC as data, est. 1983
Front desk everything
Years
Topics

2026 · Graphics & Games · AI & Machine Learning

WWDC26 · 16 min · Graphics & Games / AI & Machine Learning

Optimize custom machine learning operations with Metal tensors

Unlock powerful machine learning performance with the Metal Tensor API and Metal Performance Primitives (MPP) Tensor Ops library. Discover how to create portable operations that take advantage of Neural Accelerators in Apple M5 and A19 GPUs. Learn to build custom machine learning kernels for your Core AI applications, and find out how to work effectively with quantized data formats and GPU memory optimization.

Watch at developer.apple.com ↗

Transcript all transcripts

Chapters

  • 0:00 — Introduction
  • 0:21 — Apple's ML software stack
  • 2:25 — Managing quantized data
  • 4:23 — Multi-plane tensors
  • 5:17 — Quantized matrix multiplication
  • 9:31 — Building advanced ops
  • 13:35 — Integrating custom ops into Core AI
  • 15:25 — Next steps

Code shown on screen · 9 snippets

Create a quantized MTLTensor objectivec · at 3:53 ↗
// Creating a tensor with a quantized data type from device
//
// Describes a rank-2 tensor whose element type is
// MTLTensorDataTypeMetalFloat8E4M3 (an 8-bit E4M3 float format, judging by
// the name), marks it for compute use, and asks `device` to allocate it.
// NOTE(review): `device`, `NumCols`, and `NumRows` are defined outside this
// snippet; presumably `device` is an id<MTLDevice> — confirm.

// Tensor rank; also sizes the extents array below.
#define RANK 2

MTLTensorDescriptor *tensorDesc = [MTLTensorDescriptor new];

tensorDesc.dataType = MTLTensorDataTypeMetalFloat8E4M3;
tensorDesc.usage = MTLTensorUsageCompute;

// Extents are listed column count first; presumably dimension 0 is the
// innermost (fastest-varying) axis — confirm against the MTLTensorExtents docs.
NSInteger dimensions[RANK] = {NumCols, NumRows};
tensorDesc.dimensions = [[MTLTensorExtents alloc] initWithRank:RANK values:dimensions];

NSError *err = nil;
// NOTE(review): the snippet never inspects the result. Real code should test
// `tensor` for nil (the return value, not `err`) before using it.
id <MTLTensor> tensor = [device newTensorWithDescriptor:tensorDesc error:&err];
Declare a multi-plane tensor with scale factors objectivec · at 4:48 ↗
// Creating a tensor with a scales auxiliary plane from device
//
// Describes a rank-2 MTLTensorDataTypeMetalFloat8E4M3 tensor that carries an
// auxiliary plane of MTLTensorDataTypeMetalFloat8UE8M0 values, registered
// under MTLTensorPlaneTypeScales, then asks `device` to allocate it.
// NOTE(review): `device`, `NumCols`, and `NumRows` are defined outside this
// snippet; presumably `device` is an id<MTLDevice> — confirm.

// Tensor rank; sizes both the block-factor and the extents arrays below.
#define RANK 2

// Descriptor for the scales plane: its element type.
MTLTensorAuxiliaryPlaneDescriptor *planeDesc = [MTLTensorAuxiliaryPlaneDescriptor new];
planeDesc.dataType = MTLTensorDataTypeMetalFloat8UE8M0;

// Block factors {32, 1}: presumably one scale value per 32 elements along
// dimension 0 and per element along dimension 1 — confirm against the
// MTLTensorAuxiliaryPlaneDescriptor docs.
NSInteger blockFactors[RANK] = {32, 1};
planeDesc.blockFactors = [[MTLTensorExtents alloc] initWithRank:RANK values:blockFactors];

// Map that associates the plane descriptor with the "scales" plane role.
MTLTensorAuxiliaryPlaneDescriptorMap *auxiliaryPlanes =
    [MTLTensorAuxiliaryPlaneDescriptorMap new];
[auxiliaryPlanes setDescriptor:planeDesc forPlane:MTLTensorPlaneTypeScales];

// Descriptor for the main (data) plane.
MTLTensorDescriptor *tensorDesc = [MTLTensorDescriptor new];
tensorDesc.dataType = MTLTensorDataTypeMetalFloat8E4M3;
tensorDesc.usage = MTLTensorUsageCompute;

NSInteger dimensions[RANK] = {NumCols, NumRows};
tensorDesc.dimensions = [[MTLTensorExtents alloc] initWithRank:RANK values:dimensions];
// Attach the scales plane so the tensor is created as a multi-plane tensor.
tensorDesc.auxiliaryPlanes = auxiliaryPlanes;

NSError *err = nil;
// NOTE(review): the snippet never inspects the result. Real code should test
// `tensor` for nil (the return value, not `err`) before using it.
id <MTLTensor> tensor = [device newTensorWithDescriptor:tensorDesc error:&err];
MSL type aliases for an MXFP8 tensor handle objectivec · at 6:07 ↗
// Type aliases for an MXFP8 multi-plane tensor handle
//
// The data plane holds metal_fp8_e4m3_format elements in the device address
// space. The attached scales plane holds metal_fp8_ue8m0_format values with
// block factors 32 and 1.
// NOTE(review): these block factors presumably must agree with the host-side
// blockFactors used when the MTLTensor was created — confirm.

#include <metal_tensor>

using namespace metal;

// Scales plane: metal_fp8_ue8m0_format values, one per (32, 1) block of the
// data plane (presumably 32 elements along dimension 0).
using scales_plane = tensor_blockwise<tensor_plane_scales,
                                      device metal_fp8_ue8m0_format,
                                      32, 1>;

// Rank-2 FP8 E4M3 tensor with dynamic int extents, received as a
// tensor_handle, carrying the scales plane defined above.
using mxfp8_tensor = tensor<device metal_fp8_e4m3_format,
                            dextents<int, 2>,
                            tensor_handle,
                            scales_plane>;

// Kernel entry point. A and B are MXFP8 tensors bound at buffer indices 0 and
// 1; C is a rank-2 half-precision tensor bound at index 2.
// The body is elided on the slide.
kernel void matmul(mxfp8_tensor matrixA [[buffer(0)]],
                   mxfp8_tensor matrixB [[buffer(1)]],
                   tensor<device half, dextents<int, 2>> matrixC [[buffer(2)]])
{
    // ...
}
Declare an inline MXFP8 tensor on the stack objectivec · at 6:51 ↗
// Type aliases for an MXFP8 multi-plane tensor inline
//
// Same FP8 E4M3 data type and scales plane as a handle-based MXFP8 tensor,
// but declared tensor_inline so the shader can construct it itself from raw
// buffer pointers.

#include <metal_tensor>

using namespace metal;

// Scales plane: metal_fp8_ue8m0_format values with block factors 32 and 1.
using scales_plane = tensor_blockwise<tensor_plane_scales,
                                      device metal_fp8_ue8m0_format,
                                      32, 1>;

// Rank-2 FP8 E4M3 tensor with dynamic int extents, stored inline.
using mxfp8_tensor_inline = tensor<device metal_fp8_e4m3_format,
                                   dextents<int, 2>,
                                   tensor_inline,
                                   scales_plane>;

// Construct tensor on the stack from buffer pointers
// Extents (K, M), with what are presumably strides { 1, K }: dimension 0 is
// contiguous and dimension 1 advances by K elements, i.e. M rows of K values.
// NOTE(review): dataBufferA, scalesBufferA, K and M are defined outside this
// snippet.
mxfp8_tensor_inline matrixA(dataBufferA,
                             dextents<int, 2>(K, M),
                             array<int, 2>({ 1, K }),
                             scales_plane(scalesBufferA));
Slice tensors and run a quantized matmul objectivec · at 7:19 ↗
// Slice the tensors to extract the relevant tile
// Offsets come from tgid (presumably this threadgroup's position in the
// grid). A is offset by tgid.y * TILEM along dimension 1, B by tgid.x * TILEN
// along dimension 0, and C by both, so each threadgroup addresses its own
// TILEM x TILEN output tile.
// NOTE(review): no slice extents are given. K is left dynamic in the
// descriptor below, so presumably the op takes K from the tensors — confirm.
auto tA = matrixA.slice(0, tgid.y * TILEM);
auto tB = matrixB.slice(tgid.x * TILEN, 0);
auto tC = matrixC.slice(tgid.x * TILEN, tgid.y * TILEM);

// Set up the matmul descriptor
// M and N are compile-time tile sizes; K uses dynamic_length_v<int>.
constexpr auto descriptor = matmul2d_descriptor(TILEM,                  // M
                                                TILEN,                  // N
                                                dynamic_length_v<int>,  // K
                                                false,   // Left matrix transposed
                                                false);  // Right matrix transposed

// Op executed across 4 simdgroups (presumably the threadgroup must supply at
// least that many — confirm).
matmul2d<descriptor, execution_simdgroups<4>> op;

// Run the op — TensorOps handles dequantization automatically
op.run(tA, tB, tC);
Set up simdgroup-scoped QxK multiplication objectivec · at 10:27 ↗
// Set up the QxK matrix multiplication op.
// Runs at single-simdgroup scope. `execution_simdgroups` is the multi-simdgroup
// scope and takes a simdgroup count (e.g. `execution_simdgroups<4>`), so it
// cannot be named bare here. The single-simdgroup scope is
// `metal::execution_simdgroup`.
constexpr auto mul_qk_op_desc = matmul2d_descriptor(/* ... */);
matmul2d<mul_qk_op_desc, metal::execution_simdgroup> mul_qk_op;

// Slice Q, K, V.
// The template arguments presumably give the slice extents and the call
// arguments the starting offsets, per dimension.
// Q: D x ROWS_PER_SIMD, offset sgid * ROWS_PER_SIMD along dimension 1
// (presumably sgid is this simdgroup's index, so each simdgroup takes its own
// rows).
// K and V: D x BK, offset k along dimension 1 (presumably the current key-block
// offset from an enclosing loop not shown).
auto tQSlice = tQ.slice<D, ROWS_PER_SIMD>(0, sgid * ROWS_PER_SIMD);
auto tKSlice = tK.slice<D, BK>(0, k);
auto tVSlice = tV.slice<D, BK>(0, k);

// Create a cooperative tensor to hold this simdgroup's tile of QxK. The op
// derives its layout from the Q and K slice types; its elements are float.
auto ctQK = mul_qk_op.get_destination_cooperative_tensor<decltype(tQSlice),
                                                         decltype(tKSlice),
                                                         float>();

// Multiply QxK into the cooperative tensor.
mul_qk_op.run(tQSlice, tKSlice, ctQK);
Compute row-wise reduction for SoftMax objectivec · at 11:18 ↗
// Create a cooperative tensor to store row reduction output
// The op derives its layout from the same Q and K slice types used for the QxK
// tile; its elements are float.
auto ctTileRowMax = mul_qk_op.get_row_reduction_destination_cooperative_tensor<
                        decltype(tQSlice),
                        decltype(tKSlice),
                        float>();

// Compute max over each row of QxK tile
// -INFINITY is presumably the reduction's initial (identity) value for max.
reduce_rows(ctQK, ctTileRowMax, reduction_operation::max, -INFINITY);
Compute element-wise SoftMax with map_iterator objectivec · at 11:56 ↗
// Iterate over elements of QxK tile
// Replaces each element x of the QxK tile with exp(x - max of its row).
// No division by a row sum happens here; presumably normalization is applied
// later — confirm.
// NOTE(review): this reads `ctRowMax`, while the row-reduction step writes
// `ctTileRowMax`. Presumably ctRowMax is a running max across K tiles, updated
// from ctTileRowMax in code not shown — confirm.
// The pragma requests that clang fully unroll the loop.
#pragma clang loop unroll(full)
for (auto it = ctQK.begin(); it != ctQK.end(); it++) {
    // Fetch row max corresponding to this element
    auto row_it = ctRowMax.map_iterator(it);

    // Subtract row max from each element and compute exponent
    *it = exp(*it - *row_it);
}
Reuse cooperative tensor as matmul input objectivec · at 12:33 ↗
// Set up the op that multiplies the exponentiated QxK tile by V, at
// single-simdgroup scope.
// NOTE(review): ctO (output) and tgTensor (threadgroup staging tensor) are
// defined outside this snippet.
constexpr auto mul_sv_op_desc = matmul2d_descriptor(/* ... */);
matmul2d<mul_sv_op_desc, metal::execution_simdgroup> mul_sv_op;

// The <float, half, float> arguments presumably name the left-input, right-input
// and destination element types — confirm against the MPP docs.
if (mul_sv_op.is_compatible_as_left_input<float, half, float>(ctQK)) {
    // Directly reuse cooperative tensor as input
    auto ctQKIn = mul_sv_op.get_left_input_cooperative_tensor<float, half, float>(ctQK);
    mul_sv_op.run(ctQKIn, tVSlice, ctO);
} else {
    // Store and reload through threadgroup memory if layout is not compatible
    ctQK.store(tgTensor);
    // Barrier between the threadgroup-memory store and the reload below.
    simdgroup_barrier(mem_flags::mem_threadgroup);

    auto ctQKIn = mul_sv_op.get_left_input_cooperative_tensor<float, half, float>();
    ctQKIn.load(tgTensor);
    mul_sv_op.run(ctQKIn, tVSlice, ctO);
}

Resources