# ONNX Optimizer Examples

This page provides worked examples for the QAIRT ONNX Optimizer, progressing from
simple single-API workflows to advanced pass pipelines and custom pass authoring.

Each example is a complete, runnable Python script. The only prerequisites are a
QAIRT SDK installation and an ONNX model appropriate for the transformation being
demonstrated.

**Note:** The optimizer must be applied before [`qairt.convert()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-core-api.html#qairt.convert). All examples below
produce an optimized ONNX model (and, where applicable, updated encodings) that can
be passed directly to the converter.

- [Simple Examples](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-examples.html#simple-examples)

    - [Change Sequence and Context Length](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-examples.html#change-sequence-and-context-length)
    - [MHA to SHA Conversion](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-examples.html#mha-to-sha-conversion)
    - [MoE Adaptation](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-examples.html#moe-adaptation)
- [Advanced Examples](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-examples.html#advanced-examples)

    - [Composing a Custom Pass Pipeline](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-examples.html#composing-a-custom-pass-pipeline)
- [Custom Pass Examples](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-examples.html#custom-pass-examples)

    - [Writing and Integrating a Custom Pass](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-examples.html#writing-and-integrating-a-custom-pass)

* * *

## Simple Examples

These examples demonstrate the most common optimizer tasks using the high-level API.

### Change Sequence and Context Length

Use [`change_seq_and_context_length()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-api.html#qairt.optimizer.onnx.change_seq_and_context_length) to rewrite the
sequence length (AR) and context length (CL) of an LLM without re-quantization.

The constraint `1 ≤ new_seq_length ≤ new_context_length - 1` must hold.
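Because the constraint is cheap to check, you can validate it before loading a large model. The snippet below is a minimal sketch. `check_ar_cl` is a hypothetical helper, not part of the QAIRT API, and it encodes only the inequality stated above.

```python
def check_ar_cl(new_seq_length: int, new_context_length: int) -> None:
    """Hypothetical helper: enforce 1 <= new_seq_length <= new_context_length - 1.

    Not part of the QAIRT API; it only mirrors the documented constraint.
    """
    if not 1 <= new_seq_length <= new_context_length - 1:
        raise ValueError(
            "Invalid AR/CL pair: need 1 <= AR <= CL - 1, "
            f"got AR={new_seq_length}, CL={new_context_length}"
        )


# The AR/CL pairs used on this page pass silently.
check_ar_cl(128, 2048)
check_ar_cl(128, 4096)

# AR must be strictly smaller than CL, so this pair is rejected.
try:
    check_ar_cl(2048, 2048)
except ValueError as err:
    print(err)
```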

```python
from qairt.optimizer.onnx import GraphContext, change_seq_and_context_length

# Step 1: Load the model.
ctx = GraphContext.from_files(
    model_path="llm_model.onnx",
    encodings_path="llm_model.encodings",
)

# Step 2: Change AR to 128 and CL to 2048.
# The function modifies ctx in-place and returns the same ctx.
# Use copy.deepcopy(ctx) first if you need to preserve the original.
change_seq_and_context_length(ctx, new_seq_length=128, new_context_length=2048)

# Step 3: Export.
exported = ctx.export(path="./output_reshaped", prefix="model_ar128_cl2048")
print(f"Model saved to: {exported.onnx_path}")
```

If your model uses non-standard tensor names that the optimizer cannot auto-detect,
provide explicit axis denotation seed rules via `axis_denotation_config`:

```python
from qairt.optimizer.onnx import (
    GraphContext,
    change_seq_and_context_length,
    AxisDenotationConfig,
    AxisDenotationSeedRule,
    AxisDenotation,
)

# Step 1: Load the model.
ctx = GraphContext.from_files(
    model_path="llm_model.onnx",
    encodings_path="llm_model.encodings",
)

# Step 2: Build a config that maps a custom input tensor to its axis semantics.
# name_pattern is a regular expression matched against tensor names.
config = AxisDenotationConfig(
    custom_seed_rules=[
        AxisDenotationSeedRule(
            name_pattern=r"my_custom_input",
            denotations=[AxisDenotation.BATCH, AxisDenotation.SEQ_LENGTH],
        )
    ]
)

# Step 3: Change AR to 128 and CL to 4096 using the custom config.
change_seq_and_context_length(ctx, new_seq_length=128, new_context_length=4096,
                              axis_denotation_config=config)

# Step 4: Export. The prefix reflects the new AR/CL values.
exported = ctx.export(path="./output_reshaped", prefix="model_ar128_cl4096")
print(f"Model saved to: {exported.onnx_path}")
```
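A seed rule pairs a tensor-name pattern with per-axis meanings. This pairing lets the optimizer locate the sequence-length axis of an input it could not auto-detect. The plain-Python sketch below models that idea. The string labels, the `matches` and `relabel_shape` helpers, and the use of `re.fullmatch` are all assumptions for illustration, not the optimizer's implementation. Check the `AxisDenotationSeedRule` reference for the real matching semantics.

```python
import re

# Hypothetical stand-ins for AxisDenotation members (illustration only).
BATCH, SEQ_LENGTH = "BATCH", "SEQ_LENGTH"


def matches(name_pattern: str, tensor_name: str) -> bool:
    # Assumption: the pattern must match the whole tensor name.
    # The real optimizer may use different matching rules (e.g. re.search).
    return re.fullmatch(name_pattern, tensor_name) is not None


def relabel_shape(shape, denotations, new_sizes):
    """Return a new shape where each denoted axis takes its size from new_sizes."""
    if len(shape) != len(denotations):
        raise ValueError("one denotation is needed per axis")
    return [new_sizes.get(label, dim) for dim, label in zip(shape, denotations)]


# An example input named "my_custom_input" with shape [batch=1, seq=64].
print(matches(r"my_custom_input", "my_custom_input"))    # True
print(matches(r"my_custom_input", "my_custom_input_2"))  # False
print(relabel_shape([1, 64], [BATCH, SEQ_LENGTH], {SEQ_LENGTH: 128}))  # [1, 128]
```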

### MHA to SHA Conversion

Use [`convert_mha_to_sha()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-api.html#qairt.optimizer.onnx.convert_mha_to_sha) to convert all Multi-Head
Attention blocks to Single-Head Attention and apply full layout optimization in
one call.

```python
from qairt.optimizer.onnx import GraphContext, convert_mha_to_sha

# Step 1: Load the model.
ctx = GraphContext.from_files(
    model_path="llm_model.onnx",
    encodings_path="llm_model.encodings",   # omit for float models
)

# Step 2: Run MHA→SHA with full layout optimization.
# - m2s_head_split_map: control how heads are split, e.g. {-1: 1} splits all
#   heads to size 1 (default behaviour for most models).
# - permute_kv_cache_io: permute the KV-cache inputs/outputs from
#   [batch, head, ...] to [head, batch, ...], a more HTP-friendly
#   layout that gives better on-target performance. Enable for any LLM
#   that has a KV cache (essentially all transformer-based LLMs); leave
#   False only for non-LLM models or models without a KV cache.
# - enable_experimental_layout_optimization: enable only if the default layout
#   optimization causes a performance regression on your specific model.
convert_mha_to_sha(ctx, permute_kv_cache_io=True)

# Step 3: Export.
exported = ctx.export(path="./output_mha2sha", prefix="model_sha")
print(f"Model saved to: {exported.onnx_path}")
```
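The conversion preserves the model's outputs for a simple reason. Splitting the Q/K/V projection weights per head and running each head as its own single-head attention block computes the same result as one wide projection followed by a head split. The pure-Python sketch below checks this numerically on tiny random matrices. It is a conceptual model only: it has no masking, no KV cache, and no output projection, and it does not reflect how `convert_mha_to_sha()` actually rewrites the graph.

```python
import math
import random


def matmul(a, b):
    """[n x k] @ [k x m] -> [n x m] on nested lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def cols(m, s):
    """Take the columns selected by slice s from a row-major matrix."""
    return [row[s] for row in m]


def attention(q, k, v):
    """Single-head scaled dot-product attention: softmax(q @ k^T / sqrt(d)) @ v."""
    scale = 1.0 / math.sqrt(len(q[0]))
    k_t = [list(col) for col in zip(*k)]
    weights = []
    for row in matmul(q, k_t):
        m = max(row)
        e = [math.exp((s - m) * scale) for s in row]
        total = sum(e)
        weights.append([x / total for x in e])
    return matmul(weights, v)


def concat_heads(head_outputs):
    """Concatenate per-head outputs along the feature axis."""
    return [sum((h[i] for h in head_outputs), []) for i in range(len(head_outputs[0]))]


def mha(x, wq, wk, wv, num_heads):
    # Multi-head: one wide projection per Q/K/V, then split the features into heads.
    q, k, v = matmul(x, wq), matmul(x, wk), matmul(x, wv)
    d = len(wq[0]) // num_heads
    outputs = []
    for h in range(num_heads):
        s = slice(h * d, (h + 1) * d)
        outputs.append(attention(cols(q, s), cols(k, s), cols(v, s)))
    return concat_heads(outputs)


def sha(x, wq, wk, wv, num_heads):
    # Single-head: slice the projection weights first, so each head is an
    # independent block with its own small MatMuls.
    d = len(wq[0]) // num_heads
    outputs = []
    for h in range(num_heads):
        s = slice(h * d, (h + 1) * d)
        q, k, v = matmul(x, cols(wq, s)), matmul(x, cols(wk, s)), matmul(x, cols(wv, s))
        outputs.append(attention(q, k, v))
    return concat_heads(outputs)


rng = random.Random(0)


def rand(rows, n_cols):
    return [[rng.uniform(-1, 1) for _ in range(n_cols)] for _ in range(rows)]


seq, hidden, num_heads = 4, 8, 2
x = rand(seq, hidden)
wq, wk, wv = rand(hidden, hidden), rand(hidden, hidden), rand(hidden, hidden)
out_mha = mha(x, wq, wk, wv, num_heads)
out_sha = sha(x, wq, wk, wv, num_heads)
max_diff = max(abs(a - b) for ra, rb in zip(out_mha, out_sha) for a, b in zip(ra, rb))
print(f"max |MHA - SHA| = {max_diff:.2e}")
```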

### MoE Adaptation

Use [`adapt_moe()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-api.html#qairt.optimizer.onnx.adapt_moe) to prepare a Mixture-of-Experts (MoE)
LLM for deployment. The function adapts the AR=N and AR=1 expert sub-graphs,
inlines local functions, and cleans up dead code.

```python
from qairt.optimizer.onnx import GraphContext, adapt_moe

# Step 1: Load the model. Pass encodings_path for quantized models.
ctx = GraphContext.from_files(
    model_path="moe_model.onnx",
    encodings_path="moe_model.encodings",   # omit for float models
)

# Step 2: Adapt the MoE structure.
# - overridden_subselection: override the number of experts selected per token
#   (inferred from the model if not specified).
# - remove_op_predicate: set True to remove op-predicate Where ops.
# - enable_validation: set True to verify outputs with ONNX Runtime (slower).
adapt_moe(ctx)

# Step 3: Export the optimized model and encodings.
exported = ctx.export(path="./output_moe", prefix="model_moe_adapted")
print(f"Model saved to: {exported.onnx_path}")
if exported.encodings_path:
    print(f"Encodings saved to: {exported.encodings_path}")
```
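`overridden_subselection` sets the number of experts selected per token, often called top-k. The sketch below shows what that selection means for one token's router logits. It assumes a softmax router that keeps the top-k experts and renormalizes their weights. This is one common MoE design and may not match your model; some models skip the renormalization step. `route_token` is a hypothetical helper, not a QAIRT API.

```python
import math


def route_token(router_logits, k):
    """Pick the k highest-scoring experts for one token; return {expert: weight}.

    Assumption: softmax over all experts, keep the top-k, renormalize to sum to 1.
    """
    m = max(router_logits)
    probs = [math.exp(x - m) for x in router_logits]
    total = sum(probs)
    probs = [p / total for p in probs]
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    kept = sum(probs[i] for i in top)
    return {i: probs[i] / kept for i in top}


# One token, 8 experts: with k=2, only two experts run for this token.
logits = [0.1, 2.0, -1.0, 0.5, 1.5, -0.3, 0.0, 0.2]
weights = route_token(logits, k=2)
print(sorted(weights))  # [1, 4]
```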

* * *

## Advanced Examples

### Composing a Custom Pass Pipeline

When the high-level APIs are not sufficient, you can compose passes manually. The
following example first changes the sequence and context length through the high-level
`change_seq_and_context_length()` API (whose underlying pass is
[`IOShapeRewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.io_shape_rewriter.IOShapeRewriter)).
It then applies [`ParallelizeSerialOpsRewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.simplification.ParallelizeSerialOpsRewriter)
to restructure serial associative ops into balanced trees. Finally, it demonstrates the
experimental [`LinearToConvPass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.experimental.LinearToConvPass).

**Warning:** [`LinearToConvPass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.experimental.LinearToConvPass)
is an experimental pass and is not guaranteed to be stable or general.
Validate its outputs thoroughly before use in a production workflow.
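"Restructure serial associative ops into balanced trees" means the following. A serial chain such as `((a + b) + c) + d` has depth 3. The balanced form `(a + b) + (c + d)` computes the same sum with the same number of ops, but at depth 2, and its independent ops can run in parallel. The sketch below rebuilds a chain as a balanced tree and compares depths. It illustrates the idea and is not the pass's implementation. Reordering is exact for integers, but floating-point addition is not strictly associative, so float results can differ slightly.

```python
import operator


def serial_chain(operands):
    """Left-to-right chain (((x0 op x1) op x2) op ...) as nested tuples."""
    tree = operands[0]
    for x in operands[1:]:
        tree = (tree, x)
    return tree


def balanced(operands):
    """Recursively pair the two halves so the depth is ceil(log2(n))."""
    if len(operands) == 1:
        return operands[0]
    mid = len(operands) // 2
    return (balanced(operands[:mid]), balanced(operands[mid:]))


def depth(tree):
    """Number of binary ops on the longest path from a leaf to the root."""
    if isinstance(tree, tuple):
        return 1 + max(depth(tree[0]), depth(tree[1]))
    return 0


def evaluate(tree, op):
    if isinstance(tree, tuple):
        return op(evaluate(tree[0], op), evaluate(tree[1], op))
    return tree


xs = list(range(1, 9))  # 8 operands -> 7 binary Add nodes in both forms
chain, tree = serial_chain(xs), balanced(xs)
print(depth(chain), depth(tree))                                      # 7 3
print(evaluate(chain, operator.add) == evaluate(tree, operator.add))  # True
```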

```python
from qairt.optimizer.onnx import GraphContext, change_seq_and_context_length
from qairt.optimizer.onnx.passes import LayoutOptRewriter, ShapeInference
from qairt.optimizer.onnx.passes.cleaning import DeadCodeRemovalRewriter
from qairt.optimizer.onnx.passes.experimental import (
    CanonizeGemmPass,
    LinearToConvPass,
)
from qairt.optimizer.onnx.passes.simplification import (
    ParallelizeSerialOpsRewriter,
)

# Step 1: Load model.
ctx = GraphContext.from_files(
    model_path="llm_model.onnx",
    encodings_path="llm_model.encodings",
)

# Step 2: Modify sequence / context length using the high-level API.
change_seq_and_context_length(ctx, new_seq_length=128, new_context_length=2048)

# Step 3: Parallelize long chains of serial associative ops (Add, Mul, etc.).
# The default config targets chains of 3+ serial nodes.
config = ParallelizeSerialOpsRewriter.Config(
    op_types=("Add", "Mul"),   # restrict to these op types
    op_num_threshold=4,        # only rewrite chains of 4+ nodes
)
rewritten = ParallelizeSerialOpsRewriter(config).apply(ctx)
print(f"Parallelized {rewritten} serial op chain(s).")
```

**Warning:** [`CanonizeGemmPass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.experimental.CanonizeGemmPass) must be
called before [`LinearToConvPass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.experimental.LinearToConvPass)
if the model contains `Gemm` ops. Omitting this step causes `LinearToConvPass`
to silently skip all `Gemm` nodes, producing an incomplete transformation.

```python
# Continued from the previous block (same script).

# Step 4 (Experimental): Convert MatMul/Gemm to Conv1x1.
#   NOTE: This is experimental; validate outputs before production use.
#   If the model has Gemm ops, CanonizeGemmPass MUST be run first to normalize
#   all Gemm variants into a canonical form that LinearToConvPass can handle.
CanonizeGemmPass().apply(ctx)   # required before LinearToConvPass if model has Gemm
converted = LinearToConvPass().apply(ctx)
print(f"Converted {converted} linear op(s) to Conv1x1.")

# Step 5: Remove dead nodes left by the conversion, then re-run shape
# inference to propagate updated shapes.
DeadCodeRemovalRewriter().apply(ctx)
ShapeInference().apply(ctx)

# Step 6: Apply layout optimization to clean up inserted transposes.
LayoutOptRewriter().apply(ctx)
DeadCodeRemovalRewriter().apply(ctx)

# Step 7: Export.
ctx.export("./output_pipeline", prefix="optimized_llm_model")
```
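Step 4 rests on two equivalences, sketched here in plain Python. First, an ONNX `Gemm` computes `Y = alpha * A' * B' + beta * C`, and with `transA=0` it can be folded into a plain `MatMul` plus bias by absorbing `transB` and `alpha` into the weight. Second, a linear layer applied to every token equals a 1x1 convolution in which the tokens are spatial positions. What `CanonizeGemmPass` actually produces is not specified on this page. The folding below (`fold_gemm`) is a hypothetical canonical form chosen for illustration, and none of these helpers are QAIRT APIs.

```python
def matmul(a, b):
    """[n x k] @ [k x m] -> [n x m] on nested lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def gemm(a, b, c, alpha=1.0, beta=1.0, trans_b=False):
    """ONNX Gemm semantics with transA=0: Y = alpha * A @ B' + beta * C (C is a per-column bias)."""
    y = matmul(a, transpose(b) if trans_b else b)
    return [[alpha * row[j] + beta * c[j] for j in range(len(c))] for row in y]


def fold_gemm(b, c, alpha=1.0, beta=1.0, trans_b=False):
    """Hypothetical canonical form: absorb transB and alpha into W, beta into the bias."""
    w = transpose(b) if trans_b else b
    return [[alpha * x for x in row] for row in w], [beta * x for x in c]


def conv1x1(feature_map, weight, bias):
    """1x1 convolution on a [C_in][positions] map; weight is [C_out][C_in] (1x1 kernel squeezed)."""
    out = matmul(weight, feature_map)
    return [[val + bias[o] for val in out[o]] for o in range(len(out))]


tokens, c_in, c_out = 3, 4, 2
x = [[0.1 * (p + 1) * (ch - 1.5) for ch in range(c_in)] for p in range(tokens)]
b = [[0.2 * (o - ch) for ch in range(c_in)] for o in range(c_out)]  # [C_out][C_in], used with transB=1
c = [0.5, -0.25]

ref = gemm(x, b, c, alpha=2.0, beta=0.5, trans_b=True)        # [tokens][C_out]
w, bias = fold_gemm(b, c, alpha=2.0, beta=0.5, trans_b=True)  # W: [C_in][C_out]
# Tokens become spatial positions: the feature map is X^T, the conv weight is W^T.
conv = conv1x1(transpose(x), transpose(w), bias)              # [C_out][tokens]
max_diff = max(abs(ref[p][o] - conv[o][p]) for p in range(tokens) for o in range(c_out))
print(f"max |Gemm - Conv1x1| = {max_diff:.2e}")
```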

* * *

## Custom Pass Examples

### Writing and Integrating a Custom Pass

The optimizer framework is designed to be extensible. You can write a custom pass
by subclassing [`BasePredicatePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass)
(for per-node match→rewrite patterns) or
[`BasePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass) (for passes that
traverse the whole graph).

**Key points when writing a custom pass:**

- Update tensor encoding metadata (`v.meta["extra_info"]`) whenever you create or
replace a tensor. Use the base-class helpers
[`mark_value_as_copy()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass.mark_value_as_copy) and
[`mark_value_as_slice()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass.mark_value_as_slice) to
propagate encodings correctly.
- Use the built-in passes (for example, those in
`qairt.optimizer.onnx.passes.cleaning`) as reference implementations.
- Return the number of modifications from `apply()`, or `True` from `rewrite()`
when a predicate pass changes the graph, so that the framework can decide whether
to run the pass again.
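This return-value contract is what lets a driver re-run a pass until the pass stops changing the graph. The loop below is a hypothetical sketch of such a driver on a toy list-of-ops "graph". `run_until_stable` and `remove_one_identity` are illustrative names, not QAIRT APIs, and the framework's actual scheduling may differ.

```python
from typing import Callable, List


def run_until_stable(apply_pass: Callable[[List[str]], int],
                     graph: List[str], max_iters: int = 10) -> int:
    """Re-apply a pass while it reports modifications; return the total count.

    Hypothetical driver: apply_pass mutates graph and returns how many changes it made.
    max_iters guards against a pass that never reports zero changes.
    """
    total = 0
    for _ in range(max_iters):
        changed = apply_pass(graph)
        if changed == 0:
            break
        total += changed
    return total


def remove_one_identity(ops: List[str]) -> int:
    """Toy pass: removes at most one Identity per run, so it needs several runs."""
    if "Identity" in ops:
        ops.remove("Identity")
        return 1
    return 0


ops = ["MatMul", "Identity", "Add", "Identity", "Identity", "Softmax"]
print(run_until_stable(remove_one_identity, ops))  # 3
print(ops)  # ['MatMul', 'Add', 'Softmax']
```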

The example below shows a custom pass that removes `Dropout` nodes (which are
typically no-ops during inference) by bypassing them:

```python
import onnx_ir as ir

from qairt.optimizer.onnx import GraphContext
from qairt.optimizer.onnx.passes import (
    BasePredicatePass,
    MatchInfoProtocol,
    ShapeInference,
)
from qairt.optimizer.onnx.passes.cleaning import DeadCodeRemovalRewriter


class RemoveDropoutPass(BasePredicatePass):
    """Remove Dropout nodes; replace their primary output with their input."""

    def match(self, graph: ir.Graph, node: ir.Node) -> bool:
        # Match any Dropout node in the default ONNX domain.
        return node.op_type == "Dropout" and node.domain == ""

    def rewrite(
        self,
        graph: ir.Graph,
        node: ir.Node,
        match_info: MatchInfoProtocol | None = None,
    ) -> bool:
        data_input = node.inputs[0]
        data_output = node.outputs[0]  # primary output: the (possibly dropped) data

        if data_input is None:
            return False

        # Copy encoding / metadata from the input to the bypassed output.
        self.mark_value_as_copy(graph, data_input, data_output)

        # Redirect all consumers of the Dropout output to the Dropout input.
        # Do NOT call graph.remove() here; let DeadCodeRemovalRewriter clean
        # up unreachable nodes after the pass completes.
        data_output.replace_all_uses_with(data_input)
        return True


# --- Integrating the custom pass into a pipeline ---

ctx = GraphContext.from_files(
    model_path="model_with_dropout.onnx",
    encodings_path="model_with_dropout.encodings",
)

# Apply the custom pass.
removed = RemoveDropoutPass().apply(ctx)
print(f"Removed {removed} Dropout node(s).")

# Clean up any nodes and tensors that are now unused.
DeadCodeRemovalRewriter().apply(ctx)

# Re-run shape inference so subsequent passes have accurate shapes.
ShapeInference().apply(ctx)

# Export.
ctx.export("./output_custom", prefix="model_no_dropout")
```
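The two-phase pattern used above can be modeled on a toy graph: rewire consumers first, then let dead-code removal delete the orphaned nodes. The sketch below uses a hypothetical dict-based graph instead of `onnx_ir`, so it runs without QAIRT. `bypass` and `remove_dead_nodes` are illustrative names, not library functions. The sketch also assumes the bypassed output is not itself a graph output.

```python
# Toy graph: node name -> (op_type, [input tensor names], [output tensor names]).
nodes = {
    "mm": ("MatMul", ["x", "w"], ["h"]),
    "drop": ("Dropout", ["h"], ["h_drop"]),
    "relu": ("Relu", ["h_drop"], ["y"]),
}
graph_outputs = ["y"]


def bypass(nodes, node_name):
    """Point every consumer of the node's primary output at its first input."""
    _, inputs, outputs = nodes[node_name]
    src, dst = inputs[0], outputs[0]
    for _, ins, _ in nodes.values():
        ins[:] = [src if t == dst else t for t in ins]


def remove_dead_nodes(nodes, graph_outputs):
    """Keep only nodes whose outputs are (transitively) needed by a graph output."""
    producer = {t: name for name, (_, _, outs) in nodes.items() for t in outs}
    live, stack = set(), list(graph_outputs)
    while stack:
        name = producer.get(stack.pop())
        if name is not None and name not in live:
            live.add(name)
            stack.extend(nodes[name][1])
    removed = [n for n in nodes if n not in live]
    for n in removed:
        del nodes[n]
    return removed


bypass(nodes, "drop")                            # phase 1: rewire consumers only
print(nodes["relu"][1])                          # ['h']
print(remove_dead_nodes(nodes, graph_outputs))   # phase 2: ['drop']
print(sorted(nodes))                             # ['mm', 'relu']
```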

Last Published: Jun 19, 2026
