# Classes & Passes

API reference for [`GraphContext`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext) and the
individual transformation passes.

Note

This page is intended for users who:

- need **finer control** over the optimization pipeline than the
high-level functions provide;
- want to **author their own custom pass** and compose it with (or
without) the built-in passes to build a new pipeline; or
- want to understand the **internal implementation** of the optimizer.

For most use cases (MHA→SHA conversion, model splitting, MoE adaptation,
AR/CL rewriting), start with the high-level functions in
[Functions](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-api.html#qairt-optimizer-api).  For background on what the optimizer is,
when to apply each transformation, the framework concepts
(`GraphContext` + Passes), and the extensibility/composability/
verifiability properties of the framework, see
[ONNX Optimizer Overview](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-overview.html#qairt-optimizer-overview).

- [Classes](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#classes)

    - [Graph Context](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#graph-context)

        - [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#graphcontext)
    - [Export Result](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#export-result)

        - [ExportedFiles](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#exportedfiles)
        - [ExportedUseCase](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#exportedusecase)
    - [Axis Denotations](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#axis-denotations)

        - [AxisDenotationConfig](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#axisdenotationconfig)
        - [AxisDenotationSeedRule](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#axisdenotationseedrule)
        - [AxisDenotation](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#axisdenotation)
- [Passes](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#passes)

    - [Shape Inference](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#shape-inference)

        - [ShapeInference](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#shapeinference)
    - [MHA to SHA Conversion](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#mha-to-sha-conversion)

        - [MHA2SHARewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#mha2sharewriter)
        - [M2sStartPoint](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#m2sstartpoint)
    - [Layout Optimization](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#layout-optimization)

        - [LayoutOptRewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#layoutoptrewriter)
        - [ProtectLayoutSensitiveOps](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#protectlayoutsensitiveops)
        - [UnProtectLayoutSensitiveOps](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#unprotectlayoutsensitiveops)
        - [SimplifyReshapeTransposeSeqRewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#simplifyreshapetransposeseqrewriter)
    - [I/O Shape Rewriting](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#i-o-shape-rewriting)

        - [IOShapeRewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#ioshaperewriter)
    - [Cleaning](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#cleaning)

        - [DeadCodeRemovalRewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#deadcoderemovalrewriter)
        - [DeadWeightRemovalRewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#deadweightremovalrewriter)
        - [DeadFunctionRemovalRewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#deadfunctionremovalrewriter)
    - [I/O Protection](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#i-o-protection)

        - [ProtectIO](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#protectio)
        - [UnprotectIO](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#unprotectio)
    - [Simplification](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#simplification)

        - [ParallelizeSerialOpsRewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#parallelizeserialopsrewriter)
    - [Splitters](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#splitters)

        - [LLMSplitter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#llmsplitter)
    - [Experimental Passes](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#experimental-passes)

        - [LinearToConvPass](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#lineartoconvpass)
        - [CanonizeGemmPass](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#canonizegemmpass)
    - [Base Classes](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#base-classes)

        - [BasePass](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#basepass)
        - [BasePredicatePass](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#basepredicatepass)

## [Classes](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id4)

### [Graph Context](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id5)

#### [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id6)

- *class* qairt.optimizer.onnx.GraphContext(*model\_ir: Model*, *named\_encodings: Optional[dict[str, dict]] = None*, *named\_safetensors: Optional[dict[str, dict]] = None*, *updatable\_tensors: Optional[list[str]] = None*, *naming\_prefix: str = 'opt'*, *\_skip\_shape\_infer: bool = False*)

    - Bases: `object`

Central data structure for ONNX graph optimization operations.

[`GraphContext`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext) wraps an `onnx_ir.Model` together with its associated
metadata — quantization encodings, LoRA safetensors, and updatable tensor names.
All optimizer passes operate on a `GraphContext` rather than directly on the
underlying ONNX model.

**Recommended construction** — use [`from_files()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext.from_files) to load a model from disk:

    from qairt.optimizer.onnx import GraphContext

    ctx = GraphContext.from_files("model.onnx")
    # with quantization encodings:
    ctx = GraphContext.from_files("model.onnx", encodings_path="model.encodings")

Shape inference is run automatically during construction (via
[`ShapeInference`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.ShapeInference)).

After applying passes, export all artifacts with [`export()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext.export):

    ctx.export("/path/to/output/dir", prefix="model_optimized")

All graph passes operate on this context object rather than directly on the graph.

**Key attributes** (useful when writing custom passes):

- `model_ir` (`onnx_ir.Model`) — the full ONNX IR model (includes all
graphs and functions).
- `graph_ir` (`onnx_ir.Graph`) — the main graph of the model.  Passes
iterate over its nodes (`for node in ctx.graph_ir:`) and access its
initializers (`ctx.graph_ir.initializers`), inputs/outputs, and metadata
(`ctx.graph_ir.meta`).

See also

[ONNX Optimizer Examples](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-examples.html#qairt-optimizer-examples) — worked examples showing how to construct
a `GraphContext`, apply passes and high-level API functions, and export the
optimized model.

- export(*path: str | os.PathLike*, *prefix: str = 'model'*) → ExportedFiles

    - Export all model artifacts to a directory.

Writes the following files to *path*:

- `<prefix>.onnx` — the optimized ONNX model.
- `<prefix>.data` — external weight data sidecar.
- `<prefix>.encodings` — quantization encodings JSON (if present).
- `<adapter_name>.safetensors` — LoRA adapter weights (one file per adapter).
- `lora_importer_config.yaml` and `lora_tensor_names.txt` (if LoRA adapters
are present).

- Parameters

    - **path** – Directory where artifacts are saved.  Created if it does not exist.
    - **prefix** – Basename prefix for the ONNX model and encodings files.
      Defaults to `"model"`.

- Returns

    - Paths to all written files.

- Return type

    - [ExportedFiles](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.ExportedFiles)
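
As a quick way to check an output directory after `export()`, the file list above can be turned into a small helper. This is a minimal sketch that only mirrors the documented naming scheme; `expected_export_files` and its parameters are hypothetical names that are not part of `qairt`, and the helper does not call the library.

```python
from pathlib import Path


def expected_export_files(out_dir, prefix="model", adapters=(), has_encodings=False):
    """List the artifact paths implied by the documented naming scheme."""
    root = Path(out_dir)
    files = [root / f"{prefix}.onnx", root / f"{prefix}.data"]
    if has_encodings:
        files.append(root / f"{prefix}.encodings")
    # One safetensors file per LoRA adapter, named after the adapter.
    files += [root / f"{name}.safetensors" for name in adapters]
    if adapters:
        files += [root / "lora_importer_config.yaml", root / "lora_tensor_names.txt"]
    return files


paths = expected_export_files("out", prefix="model_optimized", has_encodings=True)
print([p.name for p in paths])
```

Comparing this list with the paths recorded in the returned `ExportedFiles` object is one way to confirm that every expected artifact was written.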

- *classmethod* from\_files(*model\_path: str | os.PathLike*, *encodings\_path: Optional[Union[str, PathLike]] = None*, *lora\_adapters\_path: Optional[Union[str, PathLike]] = None*, *lora\_tensor\_names\_path: Optional[Union[str, PathLike]] = None*, *naming\_prefix: str = 'opt'*, *\*\*kwargs*)

    - Load a model and its metadata from disk and return an initialized GraphContext.

This is the **recommended** way to create a [`GraphContext`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext) for external
users.

- Parameters

    - **model\_path** – Path to the ONNX model file (`.onnx`).
    - **encodings\_path** – Path to an AIMET quantization encodings JSON file.
      Supported format versions: v0.6.1 and v1.0.0.  Pass `None` for
      non-quantized models.
    - **lora\_adapters\_path** – Path to a LoRA adapters YAML config file
      (`lora_importer_config`).  YAML schema:

          use_case:
            - name: <usecase_1/adapter_1 name>
              lora_weights: <path to safetensor file for adapter_1>
              quant_overrides: <path to AIMET encodings file for adapter_1>
            - name: <usecase_2/adapter_2 name>
              lora_weights: <path to safetensor file for adapter_2>
              quant_overrides: <path to AIMET encodings file for adapter_2>
            ...

    - **lora\_tensor\_names\_path** – Path to a `.txt` file with updatable LoRA tensor names.

- Returns

    - Initialized graph context.

- Return type

    - [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)

- get\_encodings()

    - Extract quantization encodings from the graph.

Collects per-tensor encoding information stored in each tensor’s
`meta["extra_info"]` and returns it grouped by encoding-set name.

- Returns

    - Encoding information keyed by encoding-set
name (e.g. `"base"` for base-model encodings).

- Return type

    - dict[str, GraphEncodingInfo]

- get\_onnx\_proto()

    - Serialize the model into an `onnx.ModelProto`.

- get\_safetensors()

    - Extract LoRA safetensor weights from the graph.

- Returns

    - Safetensor dictionaries keyed by adapter name.  Each
inner dict maps tensor name to a NumPy array.

- Return type

    - dict[str, dict]

- get\_tracing\_info(*merged=False*)

    - Return tracing information for all recorded transformations.

- Parameters

    - **merged** – Whether to merge chainable transformations into one.

- get\_updatable\_tensor\_names()

    - Return the list of updatable LoRA tensor names in the graph.

- Returns

    - Tensor names that were marked as updatable when the
[`GraphContext`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext) was constructed.

- Return type

    - list[str]

- save\_onnx(*path: str*, *external\_data: Optional[str] = None*)

    - Save the model to an ONNX file with external data.

This method is more memory-efficient than serializing via
`onnx.save(ctx.get_onnx_proto())` because it avoids loading all weight
tensors into memory at once.

- Parameters

    - **path** – Destination file path for the `.onnx` file.  The parent directory
      is created automatically if it does not exist.
    - **external\_data** – Filename (basename only) for the external data sidecar file.
      Defaults to `<model_basename>.data` (e.g. `model.data` for
      `model.onnx`).

- save\_tracing\_info(*path: str*, *merged=False*)

    - Save tracing information to a file.

- Parameters

    - **path** – Path of the file to write the tracing information to.
    - **merged** – Whether to merge chainable transformations into one.

### [Export Result](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id7)

The [`export()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext.export) method returns an
`ExportedFiles` object that records the paths of every artifact written to disk.
Each LoRA adapter’s exported files are listed in `ExportedFiles.use_cases`
as [`ExportedUseCase`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.ExportedUseCase) objects.

#### [ExportedFiles](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id8)

- *class* qairt.optimizer.onnx.ExportedFiles(*\**, *onnx\_path: Path*, *data\_path: Union[Path, Path]*, *encodings\_path: Optional[Path] = None*, *use\_cases: list[qairt.optimizer.onnx.graph.ExportedUseCase] = []*, *lora\_tensor\_names: Optional[Path] = None*, *lora\_importer\_config: Optional[Path] = None*, *lora\_transform\_metadata\_path: Optional[Path] = None*)

    - Bases: `BaseModel`

- *field* data\_path*: Union[Path, Path]*  *[Required]*

- *field* encodings\_path*: Optional[Path]*  *= None*

- *property* info*: qairt.optimizer.onnx.graph.ExportedFileInfo | None*

    - Read-only access to export metadata, if populated.

- *field* lora\_importer\_config*: Optional[Path]*  *= None*

- *field* lora\_tensor\_names*: Optional[Path]*  *= None*

- *field* lora\_transform\_metadata\_path*: Optional[Path]*  *= None*

- model\_post\_init(*context: Any*, */*) → None

    - This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that’s what pydantic-core passes when calling it.

- Parameters

    - **self** – The BaseModel instance.
    - **context** – The context.

- *field* onnx\_path*: Path*  *[Required]*

    - Constraints

        - **path\_type** = file

- *field* use\_cases*: list[qairt.optimizer.onnx.graph.ExportedUseCase]*  *= []*

#### [ExportedUseCase](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id9)

- *class* qairt.optimizer.onnx.ExportedUseCase(*\**, *name: str*, *safetensors: Path*, *encodings: Optional[Path] = None*)

    - Bases: `BaseModel`

- *field* encodings*: Optional[Path]*  *= None*

- *field* name*: str*  *[Required]*

- *field* safetensors*: Path*  *[Required]*

    - Constraints

        - **path\_type** = file

### [Axis Denotations](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id10)

Axis denotations label individual tensor dimensions with semantic meaning (e.g.
batch, [AR](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-overview.html#term-AR), [CL](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-overview.html#term-CL)).
They are used internally by
[`IOShapeRewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.io_shape_rewriter.IOShapeRewriter)
and the high-level AR/CL API functions to locate and update the correct dimensions
throughout the graph.

#### [AxisDenotationConfig](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id11)

- *class* qairt.optimizer.onnx.AxisDenotationConfig(*input\_ids\_name\_pattern: str = 'input\_ids'*, *inputs\_embeds\_name\_pattern: str = '(input|inputs)\_embeds'*, *hidden\_states\_name\_pattern: str = '(target\_)?hidden\_states'*, *layer\_output\_name\_pattern: str = '/?model\_(layers\_\\d+\_Add/Add|embed\_tokens/Gather)\_output\_0'*, *position\_ids\_name\_pattern: str = '(swa\_)?position\_ids(\_sin|\_cos)?'*, *key\_cache\_name\_pattern: str = 'past\_key\_(\\d)+\_in'*, *value\_cache\_name\_pattern: str = 'past\_value\_(\\d)+\_in'*, *attention\_mask\_name\_pattern: str = 'attention\_mask'*, *cache\_index\_name\_pattern: str = '(swa\_)?cache\_index'*, *swa\_mask\_name\_pattern: str = 'swa\_attention\_mask'*, *swa\_key\_name\_pattern: str = 'swa\_(past\_)?key\_(\\d)+\_in'*, *swa\_value\_name\_pattern: str = 'swa\_(past\_)?value\_(\\d)+\_in'*, *recurrent\_state\_name\_pattern: str = 'recurrent\_state\_(\\d)+\_in'*, *conv\_state\_name\_pattern: str = 'conv\_state\_(\\d)+\_in'*, *linear\_attn\_mask\_name\_pattern: str = 'linear\_attn\_mask'*, *num\_last\_accepted\_name\_pattern: str = 'num\_last\_accepted'*, *conv\_masks\_name\_pattern: str = 'conv\_masks'*, *draft\_tree\_mask\_name\_pattern: str = 'draft\_tree\_mask'*, *transposed\_key\_cache: bool = True*, *lora\_alpha\_name\_pattern: str = 'lora\_alpha'*, *custom\_seed\_rules: list[qairt.optimizer.onnx.passes.axis\_denotation\_infer.config.AxisDenotationSeedRule] = &lt;factory&gt;*)

    - Bases: `PassConfig`

Configuration for axis denotation inference passes.

This config specifies regex patterns that are matched against ONNX graph input names to
determine what each dimension represents. It is used to bootstrap the denotation inference
process by identifying known tensor patterns in LLM models.

**How it works**

- The `name_pattern` is compiled as a Python regex pattern.
- If a graph input name matches the pattern, the tensor is initialised with the
given denotations.
- These initial denotations propagate through the graph via inference passes.

The config includes built-in patterns for common LLM inputs (`input_ids`,
`attention_mask`, KV caches, etc.) and supports custom patterns for non-standard
models via `custom_seed_rules`.

**Pattern matching priority**

- Custom seed rules (`custom_seed_rules`) are checked **first**.
- Rules are evaluated in the order they appear in the list.
- The first matching rule is used; subsequent rules are skipped.
- If no custom rule matches, built-in patterns are tried.
- This allows overriding built-in patterns for non-standard naming conventions.

Example:

    from qairt.optimizer.onnx import (
        AxisDenotation,
        AxisDenotationConfig,
        AxisDenotationSeedRule,
    )

    # Basic usage with defaults
    config = AxisDenotationConfig()

    # Custom model with non-standard names
    config = AxisDenotationConfig(
        custom_seed_rules=[
            AxisDenotationSeedRule(
                name_pattern=r"my_custom_input",
                denotations=[AxisDenotation.BATCH, AxisDenotation.SEQ_LENGTH],
            ),
            AxisDenotationSeedRule(
                name_pattern=r"my_cache_\d+",
                denotations=[
                    AxisDenotation.BATCH,
                    AxisDenotation.UNKNOWN,
                    AxisDenotation.PAST_SEQ_LENGTH,
                    AxisDenotation.UNKNOWN,
                ],
            ),
        ],
    )

- All pattern attributes use Python regex syntax and are case-insensitive.
- Patterns must match the entire input name (`fullmatch`, not `search`).
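
The matching priority described above can be sketched with the standard `re` module. This is an illustration of the documented order only, not the library's implementation: the `(pattern, label)` tuples, `BUILTIN_RULES`, and `resolve` are hypothetical stand-ins, and only three of the built-in default patterns are reproduced.

```python
import re

# Hypothetical stand-ins: (regex, label) pairs, not the library's classes.
BUILTIN_RULES = [
    (r"input_ids", "built-in input_ids"),
    (r"attention_mask", "built-in attention_mask"),
    (r"past_key_(\d)+_in", "built-in key cache"),
]


def resolve(name, custom_rules=()):
    """Return the label of the first rule whose pattern matches the whole name."""
    # Custom rules come first, in list order; built-ins are the fallback.
    for pattern, label in [*custom_rules, *BUILTIN_RULES]:
        if re.fullmatch(pattern, name, flags=re.IGNORECASE):
            return label
    return None


custom = [
    (r"past_key_0_in", "custom layer-0 cache"),
    (r"past_key_\d+_in", "custom key cache"),
]

print(resolve("past_key_0_in", custom))   # first custom rule wins
print(resolve("past_key_7_in", custom))   # second custom rule
print(resolve("attention_mask", custom))  # falls through to the built-ins
print(resolve("my_unknown_input", custom))
```

Because the first match wins, a narrow custom rule should be listed before a broader one that would also match the same name.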

- attention\_mask\_name\_pattern*: str*  *= 'attention\_mask'*

- cache\_index\_name\_pattern*: str*  *= '(swa\_)?cache\_index'*

- conv\_masks\_name\_pattern*: str*  *= 'conv\_masks'*

- conv\_state\_name\_pattern*: str*  *= 'conv\_state\_(\\d)+\_in'*

- custom\_seed\_rules*: list[qairt.optimizer.onnx.passes.axis\_denotation\_infer.config.AxisDenotationSeedRule]*

- draft\_tree\_mask\_name\_pattern*: str*  *= 'draft\_tree\_mask'*

- hidden\_states\_name\_pattern*: str*  *= '(target\_)?hidden\_states'*

- input\_ids\_name\_pattern*: str*  *= 'input\_ids'*

- inputs\_embeds\_name\_pattern*: str*  *= '(input|inputs)\_embeds'*

- key\_cache\_name\_pattern*: str*  *= 'past\_key\_(\\d)+\_in'*

- layer\_output\_name\_pattern*: str*  *= '/?model\_(layers\_\\d+\_Add/Add|embed\_tokens/Gather)\_output\_0'*

- linear\_attn\_mask\_name\_pattern*: str*  *= 'linear\_attn\_mask'*

- lora\_alpha\_name\_pattern*: str*  *= 'lora\_alpha'*

- num\_last\_accepted\_name\_pattern*: str*  *= 'num\_last\_accepted'*

- position\_ids\_name\_pattern*: str*  *= '(swa\_)?position\_ids(\_sin|\_cos)?'*

- recurrent\_state\_name\_pattern*: str*  *= 'recurrent\_state\_(\\d)+\_in'*

- swa\_key\_name\_pattern*: str*  *= 'swa\_(past\_)?key\_(\\d)+\_in'*

- swa\_mask\_name\_pattern*: str*  *= 'swa\_attention\_mask'*

- swa\_value\_name\_pattern*: str*  *= 'swa\_(past\_)?value\_(\\d)+\_in'*

- transposed\_key\_cache*: bool*  *= True*

- value\_cache\_name\_pattern*: str*  *= 'past\_value\_(\\d)+\_in'*

#### [AxisDenotationSeedRule](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id12)

- *class* qairt.optimizer.onnx.AxisDenotationSeedRule(*name\_pattern: str*, *denotations: list[qairt.optimizer.onnx.utils.ir\_extra\_info.AxisDenotation]*)

    - Bases: `object`

Seed rule for bootstrapping axis denotations based on ONNX graph input name patterns

This rule maps tensor name patterns to lists of axis denotations. When an ONNX graph
input’s name matches the pattern (using regex fullmatch), the specified denotations
are assigned to its axes. This bootstraps the denotation inference process.

**How seed rules work**

- The `name_pattern` is compiled as a Python regex pattern.
- If a graph input name matches the pattern, `denotations` is assigned to that tensor.
- The length of `denotations` must equal the tensor’s rank.

- Parameters

    - - **name\_pattern** – Regex pattern to match ONNX graph input names (case-insensitive)
Uses Python regex syntax. The pattern must match the entire name
- **denotations** – List of AxisDenotation values to assign to matching tensors
Length must equal the tensor’s rank

Example:

    from qairt.optimizer.onnx import AxisDenotation, AxisDenotationSeedRule

    # For a custom input named "my_input_0" with shape [batch, seq_len]:
    AxisDenotationSeedRule(
        name_pattern=r"my_input_\d+",
        denotations=[AxisDenotation.BATCH, AxisDenotation.SEQ_LENGTH],
    )

    # For KV cache inputs with shape [batch, heads, past_seq, head_dim]:
    AxisDenotationSeedRule(
        name_pattern=r"past_key_\d+_in",
        denotations=[
            AxisDenotation.BATCH,
            AxisDenotation.UNKNOWN,
            AxisDenotation.PAST_SEQ_LENGTH,
            AxisDenotation.UNKNOWN,
        ],
    )
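
To make the whole-name, case-insensitive matching and the rank requirement concrete, the sketch below reproduces both checks with the standard `re` module. `name_matches` and `check_rank` are hypothetical helpers, and raising `ValueError` on a rank mismatch is an assumption; this page does not say how the library reports that error.

```python
import re

KEY_CACHE_PATTERN = r"past_key_(\d)+_in"  # documented default for key_cache_name_pattern


def name_matches(pattern, name):
    """Whole-name, case-insensitive match, as the documentation describes."""
    return re.fullmatch(pattern, name, flags=re.IGNORECASE) is not None


def check_rank(denotations, shape):
    """Raise if the rule would not label every axis exactly once (assumed behaviour)."""
    if len(denotations) != len(shape):
        raise ValueError(f"{len(denotations)} denotations for a rank-{len(shape)} tensor")


print(name_matches(KEY_CACHE_PATTERN, "past_key_12_in"))
print(name_matches(KEY_CACHE_PATTERN, "Past_Key_12_In"))
# A prefixed name fails fullmatch even though re.search would find the pattern.
print(name_matches(KEY_CACHE_PATTERN, "model.past_key_12_in"))
print(re.search(KEY_CACHE_PATTERN, "model.past_key_12_in") is not None)

# Four denotations for a rank-4 tensor: passes silently.
check_rank(["BATCH", "UNKNOWN", "PAST_SEQ_LENGTH", "UNKNOWN"], (1, 32, 4095, 128))
```

If graph inputs carry a prefix or suffix, the pattern itself has to account for it, for example with a leading `.*`.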

- denotations*: list[qairt.optimizer.onnx.utils.ir\_extra\_info.AxisDenotation]*

- name\_pattern*: str*

#### [AxisDenotation](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id13)

- *class* qairt.optimizer.onnx.AxisDenotation(*value*)

    - Bases: `str`, `Enum`

Axis denotation values for tensor dimensions

Denotation describes what each dimension represents in the context of LLM models.
This aligns with ONNX’s dimension denotation concept, specialized for LLMs.

Values:

- `BATCH`: Batch size dimension
- `SEQ_LENGTH`: Current sequence length (autoregressive length)
- `PAST_SEQ_LENGTH`: Past sequence length in KV cache
- `CONTEXT_LENGTH`: Global context length (`PAST_SEQ_LENGTH` + `SEQ_LENGTH`)
- `SLIDING_CONTEXT_LENGTH`: Sliding-window context length for SWA models
- `UNKNOWN`: Dimension meaning is unknown or not yet inferred

- BATCH *= 'BATCH'*

- CONTEXT\_LENGTH *= 'CONTEXT\_LENGTH'*

- PAST\_SEQ\_LENGTH *= 'PAST\_SEQ\_LENGTH'*

- SEQ\_LENGTH *= 'SEQ\_LENGTH'*

- SLIDING\_CONTEXT\_LENGTH *= 'SLIDING\_CONTEXT\_LENGTH'*

- UNKNOWN *= 'UNKNOWN'*
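
Because the enum derives from both `str` and `Enum`, its members behave like plain strings. The sketch below uses a local mirror of the documented members (not an import from `qairt`) to show that behaviour and the documented relation `CONTEXT_LENGTH = PAST_SEQ_LENGTH + SEQ_LENGTH`. The sizes 128 and 4096 are made-up example values.

```python
from enum import Enum


class AxisDenotation(str, Enum):
    """Local mirror of the documented members, for illustration only."""
    BATCH = "BATCH"
    SEQ_LENGTH = "SEQ_LENGTH"
    PAST_SEQ_LENGTH = "PAST_SEQ_LENGTH"
    CONTEXT_LENGTH = "CONTEXT_LENGTH"
    SLIDING_CONTEXT_LENGTH = "SLIDING_CONTEXT_LENGTH"
    UNKNOWN = "UNKNOWN"


# Members of a (str, Enum) class compare equal to their string values.
print(AxisDenotation.BATCH == "BATCH")
print(AxisDenotation("SEQ_LENGTH") is AxisDenotation.SEQ_LENGTH)

# Hypothetical sizes: 128 new tokens per step, 4096-token context.
sizes = {AxisDenotation.SEQ_LENGTH: 128, AxisDenotation.CONTEXT_LENGTH: 4096}
sizes[AxisDenotation.PAST_SEQ_LENGTH] = (
    sizes[AxisDenotation.CONTEXT_LENGTH] - sizes[AxisDenotation.SEQ_LENGTH]
)
print(sizes[AxisDenotation.PAST_SEQ_LENGTH])
```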

* * *

## [Passes](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id14)

Passes are the building blocks of the optimizer.  For most common transformations,
prefer the [high-level API functions](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-api.html#qairt-optimizer-api).  Use passes
directly when you need fine-grained control over the optimization pipeline or when
writing custom transformations.

Tip

When composing passes manually, run
[`ShapeInference`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.ShapeInference)
after any pass that adds new tensors or modifies shapes so that subsequent passes
have accurate shape information.

### [Shape Inference](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id15)

#### [ShapeInference](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id16)

- *class* qairt.optimizer.onnx.passes.ShapeInference(*config: Optional[Config] = None*)

    - Bases: [`BasePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass)

Infer and propagate tensor shapes throughout the ONNX graph.

This pass runs ONNX shape inference combined with constant propagation to
populate shape and data-type information on all intermediate tensors.  It is
called **automatically** during `GraphContext`
initialization (unless `_skip_shape_infer=True` is passed), so most users
never need to invoke it directly.

Apply it explicitly after a transformation that introduces new tensors or
modifies shapes so that subsequent passes have accurate shape information:

    from qairt.optimizer.onnx.passes.shape_infer import ShapeInference

    ShapeInference().apply(ctx)

Config options:

- `constant_folding_size_threshold` (int, default 1 MB) — external tensors
smaller than this byte limit are preloaded into memory before constant folding.
Increase to fold larger constant sub-graphs; decrease to limit memory use.
- `overwrite_shape` (bool, default `False`) — when `True`, discards all
existing shape annotations and recomputes them from scratch.  This is slower
but guarantees consistency after aggressive graph rewrites.

- apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Infer shapes across the whole graph.
Constant propagation and shape inference are applied iteratively
until no further information (shape, type, or constant value) can be inferred.
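
The "iterate until nothing new is inferred" idea is a fixed-point loop. The toy below is a minimal sketch of that loop on a three-node graph described with plain tuples. It is not the pass's implementation; the node format, the three supported ops, and `infer_until_stable` are all hypothetical. It shows why several rounds can be needed: the `Reshape` output shape becomes known only after the `Shape` node's value has been propagated as a constant.

```python
def infer_until_stable(nodes, shapes, constants):
    """Run inference rounds until a round learns nothing; return nodes resolved."""
    total = 0
    while True:
        learned = 0
        for op, inputs, output in nodes:
            if op == "Relu" and inputs[0] in shapes and output not in shapes:
                # Elementwise: output shape equals input shape.
                shapes[output] = shapes[inputs[0]]
                learned += 1
            elif op == "Shape" and inputs[0] in shapes and output not in constants:
                # Constant propagation: a known shape becomes a known value.
                constants[output] = shapes[inputs[0]]
                shapes[output] = (len(shapes[inputs[0]]),)
                learned += 1
            elif op == "Reshape" and inputs[1] in constants and output not in shapes:
                # Shape inference that depends on a propagated constant.
                shapes[output] = tuple(constants[inputs[1]])
                learned += 1
        if learned == 0:
            return total
        total += learned


# Nodes are deliberately listed out of order so that several rounds are needed.
nodes = [
    ("Reshape", ["y", "x_shape"], "z"),
    ("Shape", ["x"], "x_shape"),
    ("Relu", ["x"], "x_relu"),
]
shapes = {"x": (2, 3, 4), "y": (4, 6)}
constants = {}
print(infer_until_stable(nodes, shapes, constants))
print(shapes["z"])
```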

### [MHA to SHA Conversion](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id17)

Multi-Head Attention (MHA) → Single-Head Attention (SHA) transformation.

This pass splits each multi-head attention block into individual per-head sub-graphs.
After the structural split, redundant `Transpose` and `Reshape` sequences are
present in the graph.  **Layout optimization** (removing those sequences to produce a
clean, efficient graph) is a separate step.

Note

`MHA2SHARewriter` performs the structural split only.  It does **not**
apply layout optimization.  Always follow it with
[`LayoutOptRewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.LayoutOptRewriter)
to simplify the redundant `Transpose` / `Reshape` sequences inserted during the
split — omitting this step leaves the graph with unnecessary operations and may
hurt on-device performance.

For the complete, recommended workflow (structural split + layout optimization in
one call), use [`convert_mha_to_sha()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-api.html#qairt.optimizer.onnx.convert_mha_to_sha) from the
[Functions](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-api.html#qairt-optimizer-api) page:

    from qairt.optimizer.onnx import GraphContext, convert_mha_to_sha

    ctx = GraphContext.from_files("model.onnx", "model.encodings")
    convert_mha_to_sha(ctx)
    ctx.save_onnx("./output/model_sha.onnx")

Use `MHA2SHARewriter` directly only when you need to interleave custom
passes between the structural split and layout optimization.
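
For the manual route, one possible arrangement is sketched below. It rests on several assumptions: that the three pass classes can be imported from `qairt.optimizer.onnx.passes` (the module path used in their documented names), that `LayoutOptRewriter` can be constructed without arguments as `MHA2SHARewriter` and `ShapeInference` can, and that a shape-inference refresh before layout optimization is appropriate. `split_then_optimize` and its `between` parameter are hypothetical names. The imports sit inside the function so that defining it requires neither the package nor a model file.

```python
def split_then_optimize(model_path, output_dir, between=()):
    """Structural MHA→SHA split, optional custom passes, then layout optimization."""
    # Imported lazily so that defining this helper does not require qairt.
    from qairt.optimizer.onnx import GraphContext
    from qairt.optimizer.onnx.passes import (
        LayoutOptRewriter,
        MHA2SHARewriter,
        ShapeInference,
    )

    ctx = GraphContext.from_files(model_path)
    MHA2SHARewriter().apply(ctx)
    for custom_pass in between:
        # Each custom pass is expected to expose apply(ctx), like the built-ins.
        custom_pass.apply(ctx)
    # Refresh shapes, since the split and any custom passes add new tensors.
    ShapeInference().apply(ctx)
    # Assumption: LayoutOptRewriter also accepts a default (no-argument) config.
    LayoutOptRewriter().apply(ctx)
    return ctx.export(output_dir, prefix="model_sha")
```

When no custom passes are needed, `convert_mha_to_sha()` remains the simpler choice.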

#### [MHA2SHARewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id18)

- *class* qairt.optimizer.onnx.passes.MHA2SHARewriter(*config: Optional[MHA2SHAConfig] = None*)

    - Bases: [`BasePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass)

Core pass that converts Multi-Head Attention (MHA) to Single-Head Attention (SHA).

This pass splits each multi-head attention block into individual single-head
attention sub-graphs by propagating head-slicing operators (GroupSlice) through
all surrounding operations.  The transformation proceeds in three stages:

1. **pre\_stage** — insert `GroupSlice→Concat` pairs after every QKV MatMul in
each attention block, marking the head boundaries.
2. **proc\_stage** — repeatedly reorder `(X → GroupSlice)` patterns to
`(GroupSlice → X)` until no further reordering is possible, effectively
distributing the per-head slicing to the leaves of the computation.
3. **post\_stage** — clean up auxiliary nodes introduced in earlier stages.
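
The reordering in the second stage relies on the fact that slicing out one head commutes with many operators. The plain-Python sketch below checks this for a projection `MatMul`: slicing the projection output per head gives the same values as projecting with the per-head slice of the weight. It illustrates the underlying identity only. `matmul` and `slice_columns` are hypothetical helpers, the numbers are arbitrary, and nothing here models the library's `GroupSlice` operator.

```python
def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def slice_columns(m, start, stop):
    """Keep columns start..stop-1 of every row."""
    return [row[start:stop] for row in m]


num_heads, head_dim = 2, 2
x = [[1, 2, 3], [4, 5, 6]]                       # [seq, hidden]
w = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 2, 0, 1]]   # [hidden, num_heads * head_dim]

for h in range(num_heads):
    lo, hi = h * head_dim, (h + 1) * head_dim
    project_then_slice = slice_columns(matmul(x, w), lo, hi)  # MatMul -> slice
    slice_then_project = matmul(x, slice_columns(w, lo, hi))  # slice -> MatMul
    assert project_then_slice == slice_then_project
    print(h, slice_then_project)
```

Once the slice sits on the weight, each head has its own smaller `MatMul`, which is what makes the per-head sub-graphs independent.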

Note

This pass applies the structural MHA→SHA transformation **only**.  Layout
optimization (re-arranging transposes/reshapes for inference efficiency) is
**not** applied automatically.  After running this pass, apply
[`LayoutOptRewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.LayoutOptRewriter)
to obtain the full benefit.

For the complete end-to-end workflow (MHA→SHA + layout optimization), use
[`convert_mha_to_sha()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-api.html#qairt.optimizer.onnx.convert_mha_to_sha) instead of this pass directly.

**Config** (see [MHA2SHAConfig fields](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#mha2sha-config-fields) table below):

- `m2s_head_split_map` — mapping from MHA head count to SHA head count,
e.g. `{-1: 1}` to split every head into size-1 heads.
- `m2s_additional_start_points` — list of
[`M2sStartPoint`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.M2sStartPoint) objects
that specify custom tensor patterns from which to begin the MHA→SHA walk for
non-standard attention architectures.

- apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Apply the rewriter on the model context.

- Parameters

    - **ctx** – The model context containing the ir.Graph and metadata

Returns: The number of nodes that were rewritten

MHA2SHAConfig fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `m2s_head_split_map` | `dict[int, int]` | `{}` | Maps input MHA head count to output SHA head count.<br>`{-1: 1}` splits all heads to size 1 (wildcard); `{128: 8}`<br>splits 128-head attention to 8 heads.  Empty dict uses the default<br>single-head behaviour. |
| `m2s_additional_start_points` | `list[M2sStartPoint]` | `[]` | Extra start points for non-standard attention architectures where the<br>default QKV MatMul detection does not apply.  See<br>[`M2sStartPoint`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.M2sStartPoint) and the<br>`M2sStartPoint` fields table below. |
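One plausible reading of the `m2s_head_split_map` lookup is sketched below. The resolution order (exact key first, then the `-1` wildcard, then a fallback of `1`) is an assumption inferred from the examples in this section, and `resolve_head_split` is a hypothetical helper, not part of the library.

```python
def resolve_head_split(head_count: int, split_map: dict[int, int]) -> int:
    """Return the mapped value for head_count (assumed lookup order, see lead-in)."""
    if head_count in split_map:   # an exact head count is assumed to win
        return split_map[head_count]
    if -1 in split_map:           # -1 acts as the wildcard entry
        return split_map[-1]
    return 1                      # empty map: default single-head behaviour


print(resolve_head_split(32, {32: 8, -1: 1}))
print(resolve_head_split(16, {32: 8, -1: 1}))
print(resolve_head_split(128, {}))
```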

#### [M2sStartPoint](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id19)

- *class* qairt.optimizer.onnx.M2sStartPoint(*name\_pattern: str*, *split\_axis: int*, *split\_map: Optional[dict[int, int]] = None*)

    - Bases: `object`

Configuration for custom starting points used in the MHA2SHA transformation.

This dataclass defines where the MHA2SHA transformation should begin for attention patterns
that don’t follow the standard QKV MatMul pattern.

Example:

    start_point = M2sStartPoint(
        name_pattern=r"past_(key|value)_(\d+)_out",  # raw string keeps \d intact
        split_axis=1,  # Head axis for 4D tensors
        split_map={32: 8, -1: 1}  # 32 heads -> 8 heads, others -> 1 head
    )

- name\_pattern*: str*

- split\_axis*: int*

- split\_map*: dict[int, int] | None*  *= None*

M2sStartPoint fields

`M2sStartPoint` describes a custom tensor pattern from which the MHA→SHA walk should begin:

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `name_pattern` | `str` | *(required)* | Regex pattern matched against tensor names.  Tensors whose names<br>match receive a `GroupSlice` insertion, enabling MHA→SHA for<br>non-standard attention patterns. |
| `split_axis` | `int` | *(required)* | The axis that contains the head dimension.  Typical values:<br>`0` for 3-D tensors `[heads, seq, dim]`, `1` for 4-D tensors<br>`[batch, heads, seq, dim]`. |
| `split_map` | `dict[int, int] \| None` | `None` (→ `{-1: 1}`) | Optional head-count mapping for this start point.  Same format as<br>`m2s_head_split_map`.  `None` defaults to `{-1: 1}`. |
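Before passing a `name_pattern` to the optimizer, it can help to check which tensor names it would select. The sketch below uses only Python's `re` module. The tensor names are made up, and the use of `fullmatch` is an assumption: this page does not say whether the library matches the whole name or searches within it.

```python
import re

# Pattern from the M2sStartPoint example, written as a raw string.
pattern = re.compile(r"past_(key|value)_(\d+)_out")

# Hypothetical tensor names, for illustration only.
names = ["past_key_0_out", "past_value_31_out", "past_key_0_in", "logits"]

matched = [name for name in names if pattern.fullmatch(name)]
layer_index = int(pattern.fullmatch("past_value_31_out").group(2))

print(matched)
print(layer_index)
```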

### [Layout Optimization](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id20)

Layout optimization passes reorganize `Transpose` and `Reshape` sequences
for more efficient on-device execution.  These passes are typically applied after
[`MHA2SHARewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.MHA2SHARewriter).

#### [LayoutOptRewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id21)

- *class* qairt.optimizer.onnx.passes.LayoutOptRewriter(*config: Optional[PassConfig] = None*)

    - Bases: [`BasePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass)

Entry pass for post-MHA→SHA layout optimization.

After [`MHA2SHARewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.MHA2SHARewriter)
has split multi-head attention into single-head sub-graphs, this pass reorganizes
the resulting Reshape/Transpose sequences into a layout that is more efficient for
on-device inference.

The internal pipeline is:

1. [`ProtectLayoutSensitiveOps`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.layout_opt.ProtectLayoutSensitiveOps)
— insert transpose guards around layout-sensitive ops (e.g. `Conv`) so they
are not accidentally permuted by subsequent passes.
2. [`SimplifyReshapeTransposeSeqRewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.layout_opt.SimplifyReshapeTransposeSeqRewriter)
— simplify chains of `Reshape / Transpose / Squeeze / Unsqueeze`.
3. `LayoutBinelewiseRewriter` *(internal, not separately documented; see source code for details)* — pushes `Transpose` ops past binary
element-wise ops (`Add`, `Mul`, etc.) towards the graph leaves so
that transposes from different branches can be merged.
4. `LayoutConcatAfterQKVMatmulsRewriter` *(internal, not separately documented)* — optimizes `Concat` ops that immediately follow QKV
MatMuls by pushing `Transpose` ops through them, enabling further
simplification.
5. `SimplifyConcatTransposeRewriter` *(internal, not separately documented)* — folds `Concat → Transpose` patterns into simpler
equivalent forms, reducing the number of ops that reach the device.
6. [`UnProtectLayoutSensitiveOps`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.layout_opt.UnProtectLayoutSensitiveOps)
— restore original layout-sensitive op layout.

Dead-code removal is interleaved between each sub-pass.  This pass is automatically
applied by [`convert_mha_to_sha()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-api.html#qairt.optimizer.onnx.convert_mha_to_sha) as part of the complete
MHA→SHA workflow.
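The "sub-passes with interleaved cleanup" structure can be sketched with stand-in classes. Everything here is hypothetical (`ToyContext`, `ToyPass`, `run_pipeline`); only the `apply(ctx) -> int` shape is taken from this page. Whether the real pass also cleans up after the last sub-pass, or counts cleanup removals in its return value, is not stated, so the sketch does neither.

```python
from dataclasses import dataclass, field


@dataclass
class ToyContext:
    """Hypothetical stand-in for GraphContext; it only records which passes ran."""
    log: list[str] = field(default_factory=list)


class ToyPass:
    """Hypothetical pass exposing the documented apply(ctx) -> int shape."""

    def __init__(self, name: str, rewrites: int) -> None:
        self.name = name
        self.rewrites = rewrites

    def apply(self, ctx: ToyContext) -> int:
        ctx.log.append(self.name)
        return self.rewrites


def run_pipeline(ctx: ToyContext, sub_passes: list[ToyPass], cleanup: ToyPass) -> int:
    """Run sub-passes in order, applying the cleanup pass between consecutive ones."""
    total = 0
    for index, sub_pass in enumerate(sub_passes):
        total += sub_pass.apply(ctx)
        if index < len(sub_passes) - 1:
            cleanup.apply(ctx)
    return total


ctx = ToyContext()
total = run_pipeline(
    ctx,
    [ToyPass("Protect", 1), ToyPass("Simplify", 3), ToyPass("UnProtect", 1)],
    ToyPass("DeadCodeRemoval", 0),
)
print(ctx.log, total)
```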

Tip

You generally do not need to invoke this pass directly.  Use
[`convert_mha_to_sha()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-api.html#qairt.optimizer.onnx.convert_mha_to_sha) for the full transformation.

- apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Apply the rewriter on the model context.

- Parameters

    - **ctx** – The model context containing the ir.Graph and metadata

Returns: The number of nodes that were rewritten

The following passes are used internally by
[`LayoutOptRewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.LayoutOptRewriter)
and can also be used standalone in custom pipelines:

#### [ProtectLayoutSensitiveOps](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id22)

- *class* qairt.optimizer.onnx.passes.layout\_opt.ProtectLayoutSensitiveOps(*config: Optional[PassConfig] = None*)

    - Bases: `LayoutBasePredicatePass`

Temporarily convert layout-sensitive ops to a channel-last representation.

Ops such as `Conv` operate on a fixed memory layout (NCHW by convention).
Layout optimization passes need to freely push `Transpose` ops past any node,
but doing so naively on `Conv` would produce an incorrect graph.

This pass inserts `Transpose` guards around each `Conv` to rewrite it into a
channel-last (NHWC) variant (renamed to `M2S_Conv_ChnLast`) before the layout
optimization passes run.  After layout optimization completes,
[`UnProtectLayoutSensitiveOps`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.layout_opt.UnProtectLayoutSensitiveOps) restores the original layout.

Note

This pass is automatically applied inside
[`LayoutOptRewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.LayoutOptRewriter).
You do not need to call it directly unless you are building a custom
optimization pipeline.

- apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Apply the rewriter on the model context.

- Parameters

    - **ctx** – The model context containing the ir.Graph and metadata

Returns: The number of nodes that were rewritten

#### [UnProtectLayoutSensitiveOps](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id23)

- *class* qairt.optimizer.onnx.passes.layout\_opt.UnProtectLayoutSensitiveOps(*config: Optional[PassConfig] = None*)

    - Bases: `LayoutBasePredicatePass`

Restore layout-sensitive ops to their original layout after layout optimization.

This pass is the counterpart of [`ProtectLayoutSensitiveOps`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.layout_opt.ProtectLayoutSensitiveOps).  It recognizes
`M2S_Conv_ChnLast` nodes (the channel-last placeholders inserted during
protection) and converts them back to standard `Conv` ops, removing the guard
transposes that were inserted.

Note

Always pair this pass with [`ProtectLayoutSensitiveOps`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.layout_opt.ProtectLayoutSensitiveOps).  Both are
invoked automatically inside
[`LayoutOptRewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.LayoutOptRewriter).

- apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Apply the rewriter on the model context.

- Parameters

    - **ctx** – The model context containing the ir.Graph and metadata

Returns: The number of nodes that were rewritten

#### [SimplifyReshapeTransposeSeqRewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id24)

- *class* qairt.optimizer.onnx.passes.layout\_opt.SimplifyReshapeTransposeSeqRewriter(*config: Optional[PassConfig] = None*)

    - Bases: [`BasePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass)

Simplify chains of Reshape / Transpose / Squeeze / Unsqueeze operations.

Consecutive sequences of shape-manipulation ops often arise after MHA→SHA
splitting and layout optimization.  Many of these sequences are equivalent to a
shorter sequence (or to a single `Reshape`, or even to a no-op).  This pass
analyses each sequence and replaces it with the minimal equivalent set of ops,
reducing the number of operations that need to execute on-device.
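One of the simplest cases, two back-to-back `Transpose` ops, shows why such sequences collapse. The sketch below follows the ONNX `Transpose` convention (output axis `i` takes input axis `perm[i]`); the helper names are illustrative and not taken from the library.

```python
def apply_perm(shape: list[int], perm: list[int]) -> list[int]:
    """Shape after Transpose(perm): output axis i takes input axis perm[i]."""
    return [shape[axis] for axis in perm]


def compose_perms(first: list[int], second: list[int]) -> list[int]:
    """Single perm equivalent to Transpose(perm=first) followed by Transpose(perm=second)."""
    return [first[axis] for axis in second]


def is_identity(perm: list[int]) -> bool:
    return perm == list(range(len(perm)))


shape = [3, 2, 4, 64]

# Two transposes fold into one.
merged = compose_perms([0, 2, 1, 3], [0, 1, 3, 2])
two_step = apply_perm(apply_perm(shape, [0, 2, 1, 3]), [0, 1, 3, 2])
one_step = apply_perm(shape, merged)

# A perm followed by its inverse is a no-op and can be dropped entirely.
round_trip = compose_perms([0, 3, 1, 2], [0, 2, 3, 1])

print(merged, one_step, is_identity(round_trip))
```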

The pass is automatically invoked multiple times by
[`LayoutOptRewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.LayoutOptRewriter)
as part of the full layout optimization pipeline.

- apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Apply the rewriter on the model context.

- Parameters

    - **ctx** – The model context containing the ir.Graph and metadata

Returns: The number of nodes that were rewritten

### [I/O Shape Rewriting](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id25)

Rewrite [AR](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-overview.html#term-AR) (sequence length) and [CL](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-overview.html#term-CL) (context length) values throughout the model.

Note

For typical use cases prefer the convenience functions
[`change_seq_and_context_length()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-api.html#qairt.optimizer.onnx.change_seq_and_context_length),
[`change_seq_length()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-api.html#qairt.optimizer.onnx.change_seq_length), and
[`change_context_length()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-api.html#qairt.optimizer.onnx.change_context_length).

#### [IOShapeRewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id26)

- *class* qairt.optimizer.onnx.passes.io\_shape\_rewriter.IOShapeRewriter(*config: Optional[IOShapeRewriterConfig] = None*)

    - Bases: [`BasePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass)

Updates model I/O shapes and constants based on new sequence/context lengths.

This pass uses axis denotations to intelligently update tensor shapes throughout
the graph when changing sequence length ([AR](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-overview.html#term-AR)) or context length ([CL](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-overview.html#term-CL)).

Steps:

1. Run `AxisDenotationInference`
to label all tensor axes with semantic meaning.
2. Compute the original sequence length and context length from the graph.
3. Validate the new sequence/context length values from config.
4. Update graph inputs and outputs based on axis denotations.
5. Update shape-constant tensors on `Reshape` / `Expand` and similar nodes.
6. Store updated lengths in graph metadata.
7. Run cleanup passes (dead code and weight removal).
8. Re-run shape inference to recompute all intermediate shapes.

- apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Apply the rewriter on the model context.

- Parameters

    - **ctx** – The model context containing the ir.Graph and metadata

Returns: The number of nodes that were rewritten

- *static* validate\_config(*seq\_length: int*, *context\_length: int*, *sliding\_context\_length: Optional[int] = None*) → None

    - Validate that `1 <= seq_length <= context_length - 1`.

- Parameters

    - **seq\_length** – Sequence length to validate
    - **context\_length** – Context length to validate
    - **sliding\_context\_length** – Optional sliding-window context length to validate (SWA models).
      When provided, it must satisfy `1 <= sliding_context_length <= context_length`.
      The sliding window is not derived from `context_length - seq_length` (so no
      relationship with `seq_length` is imposed), but it can never exceed the global
      context: a window cannot attend to more positions than the context holds.
      `sliding_context_length == context_length` is allowed (windowed attention then
      degenerates to full attention).

- Raises

    - **ValueError** – If validation fails
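The documented constraints are small enough to restate as code. The function below is an independent re-implementation of the rules listed above, written for illustration; it is not the library's source, and the names `validate_lengths` and `is_valid` are hypothetical.

```python
from typing import Optional


def validate_lengths(
    seq_length: int,
    context_length: int,
    sliding_context_length: Optional[int] = None,
) -> None:
    """Raise ValueError unless the documented length constraints hold."""
    if not 1 <= seq_length <= context_length - 1:
        raise ValueError(
            f"need 1 <= seq_length <= context_length - 1, "
            f"got seq_length={seq_length}, context_length={context_length}"
        )
    if sliding_context_length is not None:
        if not 1 <= sliding_context_length <= context_length:
            raise ValueError(
                f"need 1 <= sliding_context_length <= context_length, "
                f"got {sliding_context_length} with context_length={context_length}"
            )


def is_valid(seq_length: int, context_length: int, sliding: Optional[int] = None) -> bool:
    """Convenience wrapper that turns the ValueError into a boolean."""
    try:
        validate_lengths(seq_length, context_length, sliding)
    except ValueError:
        return False
    return True


print(is_valid(128, 4096))        # seq_length well inside the context
print(is_valid(4096, 4096))       # seq_length must be at most context_length - 1
print(is_valid(1, 8, sliding=8))  # a window equal to the full context is allowed
```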

IOShapeRewriterConfig fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `new_seq_length` | `int \| None` | `None` | New sequence length ([AR](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-overview.html#term-AR)) to apply.  `None` keeps the<br>original value (sequence length is unchanged). |
| `new_context_length` | `int \| None` | `None` | New context length ([CL](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-overview.html#term-CL)) to apply.  `None` keeps the<br>original value. |
| `axis_denotation_config` | `AxisDenotationConfig \| None` | `None` | Optional axis denotation configuration.  `None` uses the default<br>built-in patterns.  Provide a custom config if your model uses<br>non-standard tensor names. |

### [Cleaning](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id27)

Passes that remove unused elements from the graph.  It is good practice to run one
or more cleaning passes after a transformation to keep the graph compact.

#### [DeadCodeRemovalRewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id28)

- *class* qairt.optimizer.onnx.passes.cleaning.DeadCodeRemovalRewriter(*config: Optional[PassConfig] = None*)

    - Bases: [`BasePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass)

Remove nodes whose outputs are never consumed.

A node is considered dead if every one of its output tensors has no consumers
and is not a graph output.  The pass iterates the graph in reverse topological
order so that removing a node may expose its producers as dead in the same pass.

Note

The graph must be in topological order before this pass is applied.
Most transformation passes leave the graph topologically sorted, so this
is typically satisfied automatically.
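The reverse-topological sweep can be sketched on a toy graph. The `Node` type and `remove_dead_nodes` function below are illustrative stand-ins, not the library's data structures; the point is to show why a single reverse pass over a topologically sorted node list is enough to remove whole dead chains.

```python
from dataclasses import dataclass


@dataclass
class Node:
    """Minimal stand-in for a graph node: tensor names in, tensor names out."""
    name: str
    inputs: list[str]
    outputs: list[str]


def remove_dead_nodes(nodes: list[Node], graph_outputs: set[str]) -> int:
    """Remove dead nodes in place; `nodes` must be topologically sorted."""
    uses: dict[str, int] = {}
    for node in nodes:
        for tensor in node.inputs:
            uses[tensor] = uses.get(tensor, 0) + 1

    removed = 0
    for node in reversed(list(nodes)):  # consumers are visited before producers
        dead = all(
            uses.get(tensor, 0) == 0 and tensor not in graph_outputs
            for tensor in node.outputs
        )
        if dead:
            for tensor in node.inputs:  # releasing inputs may expose the producer
                uses[tensor] -= 1
            nodes.remove(node)
            removed += 1
    return removed


nodes = [
    Node("a", ["x"], ["t0"]),
    Node("b", ["t0"], ["t1"]),
    Node("c", ["t1"], ["t2"]),  # t2 is never consumed
    Node("d", ["t0"], ["y"]),   # y is a graph output
]
removed = remove_dead_nodes(nodes, graph_outputs={"y"})
print(removed, [node.name for node in nodes])
```

Removing `c` drops the last use of `t1`, so `b` is recognised as dead later in the same sweep; `a` survives because `d` still consumes `t0`.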

- apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Removes dead code from the graph.
Assumes that `graph.nodes` is topologically sorted.

Returns the number of nodes removed.

#### [DeadWeightRemovalRewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id29)

- *class* qairt.optimizer.onnx.passes.cleaning.DeadWeightRemovalRewriter(*config: Optional[PassConfig] = None*)

    - Bases: [`BasePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass)

Remove initializers (weights) that are not consumed by any node.

After graph transformations, some weight tensors may no longer be referenced by
any operation.  This pass removes them from the graph’s initializer table,
reducing model size and memory footprint.

Run this pass after any transformation that deletes or replaces nodes that had
dedicated weight inputs (e.g. after
[`DeadCodeRemovalRewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.cleaning.DeadCodeRemovalRewriter)).

- apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Removes dead weights from the graph.

Returns the number of weights removed.

#### [DeadFunctionRemovalRewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id30)

- *class* qairt.optimizer.onnx.passes.cleaning.DeadFunctionRemovalRewriter(*config: Optional[PassConfig] = None*)

    - Bases: [`BasePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass)

Remove ONNX local functions that are never called.

ONNX models can define reusable sub-graphs as *local functions*.  After
transformations such as function inlining or model splitting, some of these
functions may no longer be referenced by any node in the main graph or any
other function.  This pass scans all reachable call sites (including nested
function calls) and removes unreferenced functions from the model.
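The reachability scan can be sketched as a graph traversal over call sites. The data layout below (a set of function names called from the main graph, plus a per-function set of callees) and all function names are assumptions made for illustration.

```python
def find_dead_functions(
    main_calls: set[str],
    function_calls: dict[str, set[str]],
) -> set[str]:
    """Return local functions not reachable from the main graph.

    function_calls maps each local function name to the names it calls.
    """
    reachable: set[str] = set()
    stack = [name for name in main_calls if name in function_calls]
    while stack:
        name = stack.pop()
        if name in reachable:
            continue
        reachable.add(name)
        # Follow nested calls, ignoring names that are not local functions.
        stack.extend(callee for callee in function_calls[name] if callee in function_calls)
    return set(function_calls) - reachable


function_calls = {
    "Attention": {"SoftmaxFn"},
    "SoftmaxFn": set(),
    "OldGelu": {"ErfFn"},  # no longer called from anywhere
    "ErfFn": set(),        # only called from a dead function
}
dead = find_dead_functions({"Attention"}, function_calls)
print(sorted(dead))
```

`ErfFn` still has a caller, but that caller is itself unreachable, so both are reported. This is why the scan starts from the main graph rather than simply counting references.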

- apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Removes dead functions from the graph.

Returns the number of functions removed.

### [I/O Protection](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id31)

Use these passes when a transformation may inadvertently rename graph output
tensors (for example, when replacing an output tensor with a new tensor).
Built-in passes do not rename graph outputs, so these passes are typically only
needed when writing **custom passes** that call `safe_replace_all_uses_with()`
on a graph output tensor.

#### [ProtectIO](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id32)

- *class* qairt.optimizer.onnx.passes.ProtectIO(*config: Optional[PassConfig] = None*)

    - Bases: [`BasePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass)

Preserve graph output names against renaming by downstream passes.

Some transformation passes rename tensors as a side effect (e.g. when
replacing a node’s output with a new tensor).  If these tensors happen to be
graph outputs, the final model will have different output names, which can break
downstream tools that identify outputs by name.

`ProtectIO` renames all graph outputs with a `.protect` suffix and stores the
original names on a stack in graph metadata.  Once the sensitive passes have run,
call [`UnprotectIO`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.UnprotectIO) to restore the original names.

This pass is **re-entrant**: nested `ProtectIO` / [`UnprotectIO`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.UnprotectIO) pairs are
supported and work correctly as a stack:

    ProtectIO().apply(ctx)
    SomePass().apply(ctx)       # may rename outputs internally
    ProtectIO().apply(ctx)      # second, nested protection
    AnotherPass().apply(ctx)
    UnprotectIO().apply(ctx)    # restores inner protection
    UnprotectIO().apply(ctx)    # restores outer protection

- apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Renames all graph outputs to protect them.
Stores the original names in graph metadata.

#### [UnprotectIO](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id33)

- *class* qairt.optimizer.onnx.passes.UnprotectIO(*config: Optional[PassConfig] = None*)

    - Bases: [`BasePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass)

Restore graph output names after [`ProtectIO`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.ProtectIO) protection.

Pops the most recent set of original output names from the graph metadata stack
and restores them on the current graph outputs.  Must be called exactly once for
each preceding [`ProtectIO`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.ProtectIO) call.

Raises `RuntimeError` if called without a matching [`ProtectIO`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.ProtectIO), or
if the number of graph outputs has changed since protection was applied.
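The stack discipline shared by `ProtectIO` and `UnprotectIO` can be modelled in a few lines. `ToyGraph`, `protect_io`, and `unprotect_io` are hypothetical stand-ins. That nested protection appends the suffix again, and that names are restored by position, are assumptions of this sketch rather than documented behaviour.

```python
class ToyGraph:
    """Stand-in for a graph: output names plus a metadata stack of saved names."""

    def __init__(self, outputs: list[str]) -> None:
        self.outputs = list(outputs)
        self.protect_stack: list[list[str]] = []


def protect_io(graph: ToyGraph) -> None:
    """Save the current output names, then rename them with a .protect suffix."""
    graph.protect_stack.append(list(graph.outputs))
    graph.outputs = [name + ".protect" for name in graph.outputs]


def unprotect_io(graph: ToyGraph) -> None:
    """Restore the most recently saved output names."""
    if not graph.protect_stack:
        raise RuntimeError("unprotect called without a matching protect")
    saved = graph.protect_stack[-1]
    if len(saved) != len(graph.outputs):
        raise RuntimeError("number of graph outputs changed since protection")
    graph.protect_stack.pop()
    graph.outputs = saved


graph = ToyGraph(["logits", "past_key_0_out"])
protect_io(graph)    # outer protection
protect_io(graph)    # nested protection
unprotect_io(graph)  # restores the names saved by the inner call
unprotect_io(graph)  # restores the original names
print(graph.outputs)
```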

- apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Renames all graph outputs back to their original names.
Retrieves the original names from graph metadata.

### [Simplification](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id34)

#### [ParallelizeSerialOpsRewriter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id35)

- *class* qairt.optimizer.onnx.passes.simplification.ParallelizeSerialOpsRewriter(*config: Config | None = None*)

    - Bases: [`BasePredicatePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass)

Transform a left-skewed chain of serial associative binary ops into a balanced
binary tree, enabling parallel execution on device.

Example transformation:

    # Before (serial / left-skewed)
    t0 = op(x0, x1)
    t1 = op(t0, x2)
    out = op(t1, x3)

    # After (balanced tree)
    a0 = op(x0, x1)
    a1 = op(x2, x3)
    out = op(a0, a1)

Constraints:

- Only associative binary ops are supported (`Add`, `Mul`, `Max`, `Min`,
`And`, `Or`, `Xor`).
- All nodes in the chain must have the same encodings (or no encodings).
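The effect on evaluation depth can be checked with a small pairwise reduction. This is a sketch of the general balanced-tree idea on plain Python values, not the pass itself; `balanced_reduce` is a hypothetical helper.

```python
import operator
from functools import reduce


def balanced_reduce(op, operands):
    """Combine operands pairwise, level by level; return (result, tree depth)."""
    level = list(operands)  # must be non-empty
    depth = 0
    while len(level) > 1:
        paired = [op(level[i], level[i + 1]) for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:  # an odd element is carried up to the next level
            paired.append(level[-1])
        level = paired
        depth += 1
    return level[0], depth


values = [1, 2, 3, 4, 5, 6, 7, 8]
serial_result = reduce(operator.add, values)  # left-skewed chain, depth 7
tree_result, tree_depth = balanced_reduce(operator.add, values)
print(serial_result, tree_result, tree_depth)
```

For eight operands the serial chain needs seven dependent steps while the tree needs three. With floating-point `Add` or `Mul`, regrouping can change the low-order bits of the result, because floating-point arithmetic is only approximately associative.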

- ASSOCIATIVE\_OPS *= {'Add', 'And', 'Max', 'Min', 'Mul', 'Or', 'Xor'}*

ParallelizeSerialOpsRewriter.Config fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `op_types` | `tuple[str, ...]` | `("Add","Mul","Max","Min","And","Or","Xor")` | The set of associative binary op types to parallelize.  Must be a subset<br>of `{"Add","Mul","Max","Min","And","Or","Xor"}`. |
| `op_num_threshold` | `int` | `3` | Minimum serial chain length required to trigger rewriting.  Chains<br>shorter than this are left unchanged. |

### [Splitters](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id36)

Split a large LLM into multiple sequential sub-models.

#### [LLMSplitter](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id37)

- *class* qairt.optimizer.onnx.passes.splitters.LLMSplitter(*config: LLMSplitterConfig*)

    - Bases: `BaseGraphSplitter`

Split a large LLM ONNX model into multiple smaller sub-models.

Many LLMs are too large to fit on a single device or require pipelining
across multiple device partitions.  `LLMSplitter` cuts the model at residual-add
boundaries (i.e. the add operations at the end of each transformer layer) to
produce `N` sequential sub-models, each of whose outputs feed directly into the
inputs of the next.

The number and location of splits are controlled by `LLMSplitterConfig`
(see the **LLMSplitterConfig fields** table below).
Optional embedding and language-model head splits are also supported.

The recommended way to invoke this functionality is through
[`split_llm()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-api.html#qairt.optimizer.onnx.split_llm), which handles config construction
and encoding propagation:

    from qairt.optimizer.onnx import split_llm

    splits = split_llm(ctx, num_splits=3)
    for i, part in enumerate(splits):
        part.export("./output", prefix=f"model_part_{i+1}")

If you need lower-level control, use this class directly:

    from qairt.optimizer.onnx.passes.splitters import LLMSplitter
    from qairt.optimizer.onnx.passes.splitters.config import LLMSplitterConfig

    config = LLMSplitterConfig(num_splits=3)
    splitter = LLMSplitter(config)
    split_ctxs = splitter.split(ctx)

- split(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → list[[qairt.optimizer.onnx.GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)]

    - Splits the given ONNX model in place into the specified number of sub-models.
Splitting off the embedding and lm\_head is also supported.

- Parameters

    - **ctx** – The GraphContext to split

- Returns

    - List of GraphContext instances for each split

- Return type

    - List[[GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)]

Model topology and valid splitting points:

           │ ←── layers[0] ──→ │           │ ←── layers[-1] ──→ │
    embed ─┬── add0 ─┬── add1 ──  ···  ─┬── add(n-2) ─┬── add(n-1) ── lmhead
           └─norm─attn─┘ └─norm─ffn─┘     └─norm─attn─┘  └─norm─ffn─┘
           ↑             ↑                 ↑                            ↑
                     valid splitting points

LLMSplitterConfig fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `num_splits` | `int` | `1` | Number of sequential sub-models to produce. |
| `split_embedding` | `bool` | `False` | If `True`, place the embedding layer in its own split. |
| `split_lm_head` | `bool` | `False` | If `True`, place the language model head in its own split. |
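To get a feel for what `num_splits` implies, the sketch below partitions a layer stack into contiguous groups. It rests on assumptions not confirmed by this page: that splits fall on whole-layer boundaries and that layers are distributed as evenly as possible. The real splitter may choose other valid splitting points (the diagram above also marks points inside a layer), and `partition_layers` is a hypothetical helper.

```python
def partition_layers(num_layers: int, num_splits: int) -> list[range]:
    """Split layer indices 0..num_layers-1 into num_splits contiguous groups."""
    if not 1 <= num_splits <= num_layers:
        raise ValueError("need 1 <= num_splits <= num_layers")
    base, extra = divmod(num_layers, num_splits)
    parts = []
    start = 0
    for index in range(num_splits):
        size = base + (1 if index < extra else 0)  # earlier groups absorb the remainder
        parts.append(range(start, start + size))
        start += size
    return parts


parts = partition_layers(num_layers=32, num_splits=3)
for index, part in enumerate(parts):
    print(f"sub-model {index + 1}: layers {part.start}..{part.stop - 1}")
```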

### [Experimental Passes](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id38)

Warning

The passes in this section are **experimental**.  They are not guaranteed to be
general, stable, or free of bugs.  Validate results carefully before using them in
a production workflow.

#### [LinearToConvPass](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id39)

- *class* qairt.optimizer.onnx.passes.experimental.LinearToConvPass(*config: Optional[Config] = None*)

    - Bases: [`BasePredicatePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass)

**Experimental** — convert `MatMul` / `Gemm` (optionally followed by `Add`) to a `Conv 1×1`.

Warning

This pass is **experimental**.  It is not guaranteed to be general or stable,
and may produce incorrect results on some models.  Use with caution and validate
outputs thoroughly.

Converting linear ops to `Conv 1×1` can unlock further layout optimizations
(e.g. via
[`LayoutOptRewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.LayoutOptRewriter))
on hardware that handles convolutions more efficiently than matrix multiplications.
`Transpose` and `Reshape` ops are inserted on the input/output paths to match
the expected 4-D layout of `Conv`.

This pass is typically applied together with layout optimization.  If the model
contains `Gemm` ops, [`CanonizeGemmPass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.experimental.CanonizeGemmPass)
**must be applied first** to normalize all `Gemm` variants into a canonical form
(`transA=0`, `transB=0`, static weight) that `LinearToConvPass` can handle:

    from qairt.optimizer.onnx.passes.experimental import CanonizeGemmPass, LinearToConvPass
    from qairt.optimizer.onnx.passes import LayoutOptRewriter

    # Step 1 (required if model has Gemm ops): canonicalize Gemm variants.
    CanonizeGemmPass().apply(ctx)

    # Step 2: convert MatMul / canonicalized Gemm to Conv 1×1.
    LinearToConvPass().apply(ctx)

    # Step 3: simplify the inserted Transpose/Reshape sequences.
    LayoutOptRewriter().apply(ctx)

The inserted `Transpose` / `Reshape` ops convert tensors into and out of the layout that `Conv` expects, as the two examples below show.

**Example 1** — `Gemm` → `Conv 1×1`:

    # Input
    Subgraph(input:FLOAT[4, 64]) -> output:FLOAT[4, 32]
    {
        output = Gemm(input, gemm_weight:FLOAT[64,32], gemm_bias:FLOAT[32],
                      alpha=1.5, transA=0, transB=0)
    }

    # Output
    Subgraph(input:FLOAT[4, 64]) -> output:FLOAT[4, 32]
    {
        conv_input = Reshape(input, [4,1,1,64])
        conv_input = Transpose(conv_input, perm=[0,3,1,2])
        conv_weight = Reshape(Transpose(gemm_weight * alpha, [1,0]), [32,64,1,1])
        output = Conv(conv_input, conv_weight, gemm_bias, kernel_shape=[1,1])
    }

**Example 2** — `MatMul + Add` → `Conv 1×1`:

    # Input
    Subgraph(input:FLOAT[3,2,4,64]) -> output:FLOAT[3,2,4,32]
    {
        tmp    = MatMul(input, matmul_weight:FLOAT[64,32])
        output = Add(tmp, add_bias:FLOAT[32])
    }

    # Output
    Subgraph(input:FLOAT[3,2,4,64]) -> output:FLOAT[3,2,4,32]
    {
        conv_input  = Transpose(input, perm=[0,3,1,2])
        conv_weight = Reshape(Transpose(matmul_weight, [1,0]), [32,64,1,1])
        tmp         = Conv(conv_input, conv_weight, add_bias, kernel_shape=[1,1])
        output      = Transpose(tmp, perm=[0,2,3,1])
    }

Note

**Gemm prerequisite:** `Gemm` has many variants (different `transA` /
`transB` / `alpha` combinations).  `LinearToConvPass` only processes
**canonicalized** `Gemm` nodes — those with `transA=0`, `transB=0`,
and a static weight tensor.  If your model contains non-canonical `Gemm`
ops, you **must** run
[`CanonizeGemmPass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.experimental.CanonizeGemmPass) first:

    CanonizeGemmPass().apply(ctx)   # normalize Gemm variants
    LinearToConvPass().apply(ctx)   # then convert to Conv 1×1

Note

**Large batch dimensions:** When input[0] has a large leading dimension (e.g.
`[16384, 256]`), the inserted `Reshape` / `Transpose` pair produces a `[16384, 256, 1, 1]`
tensor that may be suboptimal on device.  Apply
[`LayoutOptRewriter`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.LayoutOptRewriter)
after this pass to eliminate such inefficiencies.
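The equivalence behind Examples 1 and 2 can be checked numerically on a tiny case. The sketch below uses plain nested lists so that it runs without extra packages. `matmul_add` and `conv1x1` are illustrative helpers written for this check, and only the core identity is covered (transposed weight reshaped to `[C_out, C_in, 1, 1]`, bias passed through unchanged).

```python
def matmul_add(x, weight, bias):
    """x: [N][C_in], weight: [C_in][C_out], bias: [C_out] -> [N][C_out]."""
    return [
        [sum(row[c] * weight[c][o] for c in range(len(weight))) + bias[o]
         for o in range(len(bias))]
        for row in x
    ]


def conv1x1(x_nchw, weight_oihw, bias):
    """x: [N][C_in][H][W], weight: [C_out][C_in][1][1] -> [N][C_out][H][W]."""
    out = []
    for sample in x_nchw:
        c_in, height, width = len(sample), len(sample[0]), len(sample[0][0])
        out.append([
            [[sum(sample[c][h][w] * weight_oihw[o][c][0][0] for c in range(c_in)) + bias[o]
              for w in range(width)]
             for h in range(height)]
            for o in range(len(weight_oihw))
        ])
    return out


x = [[1, 2], [3, 4]]             # [N=2, C_in=2]
weight = [[1, 0, 2], [0, 1, 3]]  # [C_in=2, C_out=3]
bias = [10, 20, 30]

linear_out = matmul_add(x, weight, bias)

# Reshape the input to [N, C_in, 1, 1]; transpose the weight to [C_out, C_in, 1, 1].
x_nchw = [[[[value]] for value in row] for row in x]
weight_oihw = [[[[weight[c][o]]] for c in range(len(weight))] for o in range(len(bias))]
conv_out = conv1x1(x_nchw, weight_oihw, bias)

# Drop the two trailing singleton axes to compare against the linear result.
conv_flat = [[channel[0][0] for channel in sample] for sample in conv_out]
print(linear_out, conv_flat)
```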

- apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Apply the rewriter on the model context.

- Parameters

    - **ctx** – The model context containing the ir.Graph and metadata

Returns: The number of nodes that were rewritten

LinearToConvPass.Config fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `op_types` | `tuple[str, ...]` | `("MatMul", "Gemm")` | Op types to convert to `Conv 1×1`. |

#### [CanonizeGemmPass](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id40)

- *class* qairt.optimizer.onnx.passes.experimental.CanonizeGemmPass(*config: Optional[PassConfig] = None*)

    - Bases: [`BasePredicatePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass)

**Experimental** — normalize all `Gemm` op variants into a canonical form.

Warning

This pass is **experimental**.  It is not guaranteed to be general or stable,
and may produce incorrect results on some models.  Use with caution and validate
outputs thoroughly.

`Gemm` has many variants: inputs may be in either order (static/dynamic),
and `transA` / `transB` may be 0 or 1.  [`LinearToConvPass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.experimental.LinearToConvPass) only handles
the canonical form: **dynamic input at position 0, static weight at position 1, transA=0, transB=0**.  This pass rewrites every `Gemm` node into that
form by folding transposes into the static weight constant and inserting
`Transpose` nodes around dynamic inputs as needed.

Note

This pass must be applied **before**
[`LinearToConvPass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.experimental.LinearToConvPass) if the
model contains `Gemm` ops.  `LinearToConvPass` silently skips any `Gemm`
node that is not already in canonical form.

**Example** — `Gemm(transB=1)` → canonical `Gemm(transB=0)`:

    # Input: weight is static but transposed at runtime (transB=1)
    Subgraph(input:FLOAT[4, 64]) -> output:FLOAT[4, 32]
    {
        output = Gemm(input, weight:FLOAT[32, 64], bias:FLOAT[32],
                      alpha=1.5, transA=0, transB=1)
    }
    
    # Output: weight is pre-transposed into the initializer; transB=0
    Subgraph(input:FLOAT[4, 64]) -> output:FLOAT[4, 32]
    {
        transposed_weight:FLOAT[64, 32]  # = Transpose(weight, perm=[1,0])
        output = Gemm(input, transposed_weight, bias:FLOAT[32],
                      alpha=1.5, transA=0, transB=0)
    }
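Folding `transB=1` into the weight is safe because `Gemm` computes `alpha * A' @ B' + beta * C`, so transposing a static `B` ahead of time and clearing the flag leaves the product unchanged. The sketch below checks this with a small reference `Gemm` in pure Python. The functions `gemm` and `transpose2` are hypothetical helpers written for this illustration, not the pass's implementation.

```python
def transpose2(m):
    """Transpose a 2-D nested list."""
    return [list(col) for col in zip(*m)]


def gemm(a, b, c, alpha=1.0, beta=1.0, trans_a=0, trans_b=0):
    """Reference Gemm: Y = alpha * A' @ B' + beta * C, with a 1-D C broadcast over rows."""
    a = transpose2(a) if trans_a else a
    b = transpose2(b) if trans_b else b
    return [[alpha * sum(a[i][k] * b[k][j] for k in range(len(b))) + beta * c[j]
             for j in range(len(b[0]))]
            for i in range(len(a))]


x = [[1, 2, 3],
     [4, 5, 6]]                       # dynamic input, shape [2, 3]
weight = [[1, 0, -1],
          [2, 1, 0],
          [0, 3, 1],
          [1, 1, 1]]                  # static weight, shape [4, 3], consumed with transB=1
bias = [0.5, -1.0, 2.0, 0.0]          # shape [4]

# Before: the transpose happens inside Gemm at runtime.
before = gemm(x, weight, bias, alpha=1.5, trans_b=1)

# After: the transpose is folded into the constant, and transB is cleared.
folded_weight = transpose2(weight)    # shape [3, 4]
after = gemm(x, folded_weight, bias, alpha=1.5, trans_b=0)

print(before == after)
```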

- apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Apply the rewriter on the model context.

- Parameters

    - **ctx** – The model context containing the ir.Graph and metadata

Returns: The number of nodes that were rewritten

### [Base Classes](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id41)

Inherit from one of these classes to write a custom optimization pass.  See
[ONNX Optimizer Examples](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-examples.html#qairt-optimizer-examples) for worked examples.

#### [BasePass](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id42)

- *class* qairt.optimizer.onnx.passes.BasePass(*config: Optional[PassConfig] = None*)

    - Bases: `ABC`

Abstract base class for all stateless ONNX graph rewriting passes.

Every optimization pass inherits from `BasePass` and implements [`apply()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass.apply).
Passes must be **stateless** with respect to the graph: they receive a
`GraphContext` and modify it in-place, but must
not retain references to graph objects across calls.

The [`apply()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass.apply) method returns the number of modifications made, which lets
callers check whether a pass had any effect and chain passes in loops.

To write a custom pass, subclass either [`BasePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass) (for passes that need
full control over traversal) or [`BasePredicatePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass) (for the common
match→rewrite pattern over individual nodes).

Example:

    from qairt.optimizer.onnx.passes.base import BasePass
    from qairt.optimizer.onnx.graph import GraphContext
    
    class MyCustomPass(BasePass):
        def apply(self, ctx: GraphContext) -> int:
            count = 0
            for node in list(ctx.graph_ir):
                # ... transform node ...
                count += 1
            return count
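Because `apply()` reports how many modifications it made, a caller can re-run a group of passes until a full round changes nothing. The sketch below shows one way to do that. It is self-contained and does not import the optimizer: `run_until_stable`, `FuseMatMulAdd`, and `DropIdentity` are hypothetical, and a plain list of op names stands in for a `GraphContext`.

```python
def run_until_stable(passes, ctx, max_rounds=10):
    """Apply every pass in order, repeating until a whole round reports 0 changes."""
    total = 0
    for _ in range(max_rounds):
        changed = sum(p.apply(ctx) for p in passes)
        if changed == 0:
            break
        total += changed
    return total


class FuseMatMulAdd:
    """Toy pass: replace each adjacent ("MatMul", "Add") pair with "Gemm"."""

    def apply(self, ops):
        count = 0
        i = 0
        while i < len(ops) - 1:
            if ops[i] == "MatMul" and ops[i + 1] == "Add":
                ops[i:i + 2] = ["Gemm"]
                count += 1
            i += 1
        return count


class DropIdentity:
    """Toy pass: remove every "Identity" entry."""

    def apply(self, ops):
        before = len(ops)
        ops[:] = [op for op in ops if op != "Identity"]
        return before - len(ops)


ops = ["MatMul", "Identity", "Add", "Relu"]
total_changes = run_until_stable([FuseMatMulAdd(), DropIdentity()], ops)
print(ops, total_changes)
```

In the first round the fusion finds nothing because `Identity` separates the two ops; once it is dropped, the second round can fuse them. The `max_rounds` cap is a design choice of this sketch to guard against passes that keep undoing each other.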

- *abstract* apply(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*) → int

    - Apply the rewriter on the model context.

- Parameters

    - **ctx** – The model context containing the ir.Graph and metadata

Returns: The number of nodes that were rewritten

- mark\_value\_as\_copy(*graph: Graph*, *copy\_from: Value*, *value: Value*)

    - Record that *value* is a copy of *copy\_from*, which means the two
      tensors should have:

        - the same numerical value
        - the same shape
        - the same dtype
        - the same encodings
        - the same safetensors, if present
        - the same updatable attribute, if present

- Parameters

    - **graph** – The `ir.Graph` instance.
    - **copy\_from** – The tensor copied from.
    - **value** – The tensor copied to.

- mark\_value\_as\_slice(*graph: Graph*, *slice\_from: Value*, *value: Value*, *axis*, *start*, *end*, *batch\_slice\_id*, *head\_slice\_id*)

    - Record that *value* is a contiguous slice of *slice\_from* along *axis*.

Shape and dtype of *value* are inferred automatically.  Encoding metadata
is derived by slicing the source tensor’s encodings, so per-head encodings
are correctly propagated after MHA→SHA splitting.

- Parameters

    - **graph** – The ONNX IR graph.
    - **slice\_from** – The source tensor being sliced.
    - **value** – The destination tensor (the slice result).
    - **axis** – (int) The axis along which the slice is taken.
    - **start** – (int) Start index of the slice along *axis* (inclusive).
    - **end** – (int) End index of the slice along *axis* (exclusive).
    - **batch\_slice\_id** – (int) Identifier for the batch dimension slice group.
      Used internally to correlate slices that belong to the same batch.
      Pass `0` if not applicable.
    - **head\_slice\_id** – (int) Identifier for the attention-head slice group.
      Each individual head should have a unique `head_slice_id` so that
      per-head encoding information can be tracked correctly.  Pass `0`
      if not slicing along a head axis.
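As an illustration of how the `start`, `end`, and `head_slice_id` arguments relate when an axis is split evenly into attention heads, the sketch below computes one `(start, end, head_slice_id)` triple per head. The helper `head_slices` is hypothetical and does not call the optimizer. This reference does not say whether head ids start at 0 or 1, so the starting id is left as a parameter.

```python
def head_slices(axis_size, num_heads, first_id=0):
    """Return one (start, end, head_slice_id) triple per head.

    `start` is inclusive and `end` is exclusive, matching the parameter
    descriptions above. `first_id` is an assumption of this sketch.
    """
    if num_heads <= 0 or axis_size % num_heads != 0:
        raise ValueError("axis_size must be a positive multiple of num_heads")
    head_dim = axis_size // num_heads
    return [(h * head_dim, (h + 1) * head_dim, first_id + h)
            for h in range(num_heads)]


slices = head_slices(axis_size=64, num_heads=4)
for start, end, head_id in slices:
    print(f"head_slice_id={head_id}: [{start}, {end})")
```

Each triple could then supply the `start`, `end`, and `head_slice_id` arguments of a `mark_value_as_slice()` call for the tensor produced by that head's slice.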

#### [BasePredicatePass](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#id43)

- *class* qairt.optimizer.onnx.passes.BasePredicatePass(*config: Optional[PassConfig] = None*)

    - Bases: [`BasePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass)

Base class for match-then-rewrite graph passes.

[`BasePredicatePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass) traverses every node in the graph and, for each node,
calls [`match()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass.match) to decide whether to rewrite it, and [`rewrite()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass.rewrite) to perform
the transformation.  This covers the vast majority of per-node optimizations.

Subclass this (instead of [`BasePass`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass)) when your pass:

- Operates on individual nodes rather than the whole graph at once.
- Can express its applicability check as a predicate on a single node.

**Traversal order** — set the class attribute `TRAVERSAL_ORDER` to `"top_down"`
(default) or `"bottom_up"` to control the iteration direction over the topologically
sorted node list.

**Lifecycle hooks** — override any of [`pre_rewrite_hook()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass.pre_rewrite_hook),
[`post_rewrite_hook()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass.post_rewrite_hook), [`pre_each_rewrite_hook()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass.pre_each_rewrite_hook), or
[`post_each_rewrite_hook()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass.post_each_rewrite_hook) for setup/teardown logic.
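To make the traversal order and the hook sequence concrete, the toy driver below mimics a match-then-rewrite loop. It is a hypothetical stand-in, not the `BasePredicatePass` implementation: `ToyContext`, `ToyPredicatePass`, and `DropIdentity` are invented for this sketch, a list of node names replaces the graph, and it assumes the per-rewrite hooks fire only for matched nodes.

```python
from dataclasses import dataclass, field


@dataclass
class ToyContext:
    """Stand-in for GraphContext: node names in topological order plus an event log."""
    nodes: list
    events: list = field(default_factory=list)


class ToyPredicatePass:
    TRAVERSAL_ORDER = "top_down"  # or "bottom_up"

    def match(self, ctx, node):
        raise NotImplementedError

    def rewrite(self, ctx, node):
        raise NotImplementedError

    # Lifecycle hooks: in this sketch they only record when they fire.
    def pre_rewrite_hook(self, ctx):
        ctx.events.append("pre")

    def post_rewrite_hook(self, ctx):
        ctx.events.append("post")

    def pre_each_rewrite_hook(self, ctx):
        ctx.events.append("pre_each")

    def post_each_rewrite_hook(self, ctx, rewrite_success):
        ctx.events.append(f"post_each:{rewrite_success}")

    def apply(self, ctx):
        order = list(ctx.nodes)  # snapshot, so rewrites may mutate ctx.nodes
        if self.TRAVERSAL_ORDER == "bottom_up":
            order.reverse()
        count = 0
        self.pre_rewrite_hook(ctx)
        for node in order:
            if not self.match(ctx, node):
                continue
            self.pre_each_rewrite_hook(ctx)
            success = self.rewrite(ctx, node)
            self.post_each_rewrite_hook(ctx, success)
            count += int(success)
        self.post_rewrite_hook(ctx)
        return count


class DropIdentity(ToyPredicatePass):
    TRAVERSAL_ORDER = "bottom_up"

    def match(self, ctx, node):
        return node.startswith("Identity")

    def rewrite(self, ctx, node):
        ctx.nodes.remove(node)
        ctx.events.append(f"rewrite:{node}")
        return True


ctx = ToyContext(nodes=["Identity_0", "Relu_1", "Identity_2"])
rewritten = DropIdentity().apply(ctx)
print(rewritten, ctx.nodes)
print(ctx.events)
```

With `"bottom_up"`, `Identity_2` is visited before `Identity_0`; switching the attribute to `"top_down"` reverses that order while the hook sequence stays the same.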

Example — a simple custom pass that removes `Identity` nodes:

    from qairt.optimizer.onnx.passes.base import BasePredicatePass, MatchInfoProtocol
    from qairt.optimizer.onnx.graph import GraphContext
    import onnx_ir as ir
    
    class RemoveIdentity(BasePredicatePass):
        def match(self, graph: ir.Graph, node: ir.Node) -> bool:
            return node.op_type == "Identity"
    
        def rewrite(self, graph: ir.Graph, node: ir.Node,
                    match_info: MatchInfoProtocol | None = None) -> bool:
            src = node.inputs[0]
            dst = node.outputs[0]
            dst.replace_all_uses_with(src)
            graph.remove(node, safe=True)
            return True

If you need to pass information from [`match()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass.match) to [`rewrite()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass.rewrite) without
recomputing it, return a `MatchInfoProtocol` subclass from [`match()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass.match)
and receive it as *match\_info* in [`rewrite()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass.rewrite).
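The hand-off from `match()` to `rewrite()` can be pictured with the self-contained sketch below. Everything in it is hypothetical: `ScaleInfo` stands in for a `MatchInfoProtocol` subclass, nodes are plain dictionaries, and `dispatch` is only a guess at how a driver might treat the three kinds of return value (`False`/`None`, `True`, or an info object).

```python
from dataclasses import dataclass


@dataclass
class ScaleInfo:
    """Hypothetical match-info payload carrying data found during matching."""
    scale: float


def match(node):
    # Return a payload when a Mul node has a constant operand, so that
    # rewrite() does not need to look the constant up a second time.
    if node.get("op") == "Mul" and "const" in node:
        return ScaleInfo(scale=node["const"])
    return False


def rewrite(node, match_info=None):
    # Use the payload when present; otherwise fall back to a neutral scale.
    node["scale"] = match_info.scale if match_info is not None else 1.0
    node["op"] = "Scale"
    node.pop("const", None)
    return True


def dispatch(node):
    """Skip on False/None, rewrite without info on True, else forward the payload."""
    result = match(node)
    if result is None or result is False:
        return False
    info = None if result is True else result
    return rewrite(node, info)


mul_node = {"op": "Mul", "const": 2.0}
add_node = {"op": "Add"}
print(dispatch(mul_node), mul_node)
print(dispatch(add_node), add_node)
```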

Note

Passes should take care of encodings stored in tensor `meta["extra_info"]`
when creating or replacing tensors.  Use [`mark_value_as_copy()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass.mark_value_as_copy)
or [`mark_value_as_slice()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePass.mark_value_as_slice) to propagate encoding metadata
correctly.

- *abstract* match(*graph: Graph*, *node: Node*) → bool | qairt.optimizer.onnx.passes.base.rewriter.MatchInfoProtocol

    - Return `True` (or a `MatchInfoProtocol` instance) if the node should be rewritten.

- Parameters

    - **graph** – The ONNX IR graph being traversed.
    - **node** – The current node under consideration.

- Returns

    - `False` / `None` to skip, `True` to rewrite without extra info, or a
`MatchInfoProtocol` instance to rewrite and pass structured data to
[`rewrite()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass.rewrite).

- post\_each\_rewrite\_hook(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*, *rewrite\_success: bool*)

    - Hook to execute after each rewrite has finished.

- post\_rewrite\_hook(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*)

    - Hook to execute after all possible rewrites have finished.

- pre\_each\_rewrite\_hook(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*)

    - Hook to execute before each rewrite

- pre\_rewrite\_hook(*ctx: [GraphContext](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.GraphContext)*)

    - Hook to execute before any rewrite is started

- *abstract* rewrite(*graph: Graph*, *node: Node*, *match\_info: Optional[MatchInfoProtocol] = None*) → bool

    - Perform the rewrite on *node* and return `True` if the graph was modified.

- Parameters

    - **graph** – The ONNX IR graph being traversed.
    - **node** – The matched node to rewrite.
    - **match\_info** – Extra data returned by [`match()`](https://docs.qualcomm.com/doc/80-87189-2/topic/qairt-optimizer-passes-classes.html#qairt.optimizer.onnx.passes.BasePredicatePass.match), or `None`.

- Returns

    - `True` if the graph was modified, `False` otherwise.

Last Published: Jun 19, 2026
