# Source code for ual_adapter.core.air
"""
Architecture-Agnostic Intermediate Representation (AIR) Format
This module handles the conversion of model-specific LoRA weights to a portable
format that can be transferred across different architectures.
"""
import json
import os
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import safetensors.torch
import torch
from loguru import logger
@dataclass
class AIRMetadata:
    """Metadata for AIR format adapters.

    Every field has a default, so ``AIRMetadata()`` is valid. The two mapping
    fields get a fresh, independent ``dict`` per instance.

    Attributes:
        version: Version string of the AIR format.
        source_model: Name of the model the adapter came from.
        source_architecture: Architecture name of the source model.
        source_dimensions: Mapping of dimension name to integer size for the
            source model. The set of keys is not defined here.
        adapter_rank: Rank of the adapter.
        adapter_alpha: Alpha value of the adapter.
        training_config: Free-form mapping describing the training setup.
        domain: Domain label for the adapter.
        description: Human-readable description.
        created_at: Creation timestamp as a string; no format is enforced.
    """

    version: str = "1.0.0"
    source_model: str = ""
    source_architecture: str = ""
    source_dimensions: Dict[str, int] = field(default_factory=dict)
    adapter_rank: int = 16
    adapter_alpha: float = 16.0
    training_config: Dict[str, Any] = field(default_factory=dict)
    domain: str = ""
    description: str = ""
    created_at: str = ""

    def __post_init__(self):
        """Normalise an explicitly passed ``None`` mapping to an empty dict.

        Earlier versions used ``None`` as the default for the mapping fields,
        so callers (or a JSON ``null``) may still pass it explicitly.
        """
        if self.source_dimensions is None:
            self.source_dimensions = {}
        if self.training_config is None:
            self.training_config = {}
class AIRFormat:
    """
    Handles conversion between model-specific LoRA weights and portable AIR format.

    The AIR format uses semantic role-based naming instead of model-specific
    parameter names, enabling cross-architecture transfer. Each AIR key has the
    form ``layer_<index>.<role>``, where ``<role>`` is one of the keys of
    ``ATTENTION_ROLES`` / ``MLP_ROLES``.
    """

    # Semantic role definitions: role -> substrings that identify the role in
    # a parameter name. Matching is case-insensitive; the first role with a
    # matching substring wins, attention roles being checked before MLP roles.
    #
    # NOTE(review): because of that ordering, the bare "dense" pattern also
    # captures names containing "intermediate.dense" / "output.dense", and
    # "c_proj" classifies any name containing it as mlp_down. Confirm whether
    # that is intended for the architectures you export from.
    ATTENTION_ROLES = {
        "attention_query": ["q_proj", "query", "q_lin", "Wqkv.q", "c_attn"],
        "attention_key": ["k_proj", "key", "k_lin", "Wqkv.k"],
        "attention_value": ["v_proj", "value", "v_lin", "Wqkv.v"],
        "attention_output": ["o_proj", "out_proj", "dense"],
    }
    MLP_ROLES = {
        "mlp_up": ["up_proj", "w1", "c_fc", "intermediate.dense", "mlp.fc1"],
        "mlp_down": ["down_proj", "w2", "c_proj", "output.dense", "mlp.fc2"],
        "mlp_gate": ["gate_proj", "w3", "gate", "mlp.gate"],
    }

    # Per-architecture naming patterns: architecture -> role -> {"pattern": ...}.
    # "{layer}" is substituted with the layer index.
    # This would load from config files in production.
    _ARCHITECTURE_BINDERS = {
        "gpt2": {
            "attention_query": {"pattern": "transformer.h.{layer}.attn.c_attn"},
            "attention_key": {"pattern": "transformer.h.{layer}.attn.c_attn"},
            "attention_value": {"pattern": "transformer.h.{layer}.attn.c_attn"},
            "mlp_up": {"pattern": "transformer.h.{layer}.mlp.c_fc"},
            "mlp_down": {"pattern": "transformer.h.{layer}.mlp.c_proj"},
        },
        "llama": {
            "attention_query": {"pattern": "model.layers.{layer}.self_attn.q_proj"},
            "attention_key": {"pattern": "model.layers.{layer}.self_attn.k_proj"},
            "attention_value": {"pattern": "model.layers.{layer}.self_attn.v_proj"},
            "attention_output": {"pattern": "model.layers.{layer}.self_attn.o_proj"},
            "mlp_up": {"pattern": "model.layers.{layer}.mlp.up_proj"},
            "mlp_down": {"pattern": "model.layers.{layer}.mlp.down_proj"},
            "mlp_gate": {"pattern": "model.layers.{layer}.mlp.gate_proj"},
        },
        "pythia": {
            "attention_query": {"pattern": "gpt_neox.layers.{layer}.attention.query"},
            "attention_key": {"pattern": "gpt_neox.layers.{layer}.attention.key"},
            "attention_value": {"pattern": "gpt_neox.layers.{layer}.attention.value"},
            "mlp_up": {"pattern": "gpt_neox.layers.{layer}.mlp.dense_h_to_4h"},
            "mlp_down": {"pattern": "gpt_neox.layers.{layer}.mlp.dense_4h_to_h"},
        },
        # NOTE(review): mlp_gate maps to "mlp.c_proj" here, while MLP_ROLES
        # lists "c_proj" under mlp_down -- confirm against the Qwen module layout.
        "qwen": {
            "attention_query": {"pattern": "transformer.h.{layer}.attn.c_attn"},
            "attention_key": {"pattern": "transformer.h.{layer}.attn.c_attn"},
            "attention_value": {"pattern": "transformer.h.{layer}.attn.c_attn"},
            "attention_output": {"pattern": "transformer.h.{layer}.attn.c_proj"},
            "mlp_up": {"pattern": "transformer.h.{layer}.mlp.w1"},
            "mlp_down": {"pattern": "transformer.h.{layer}.mlp.w2"},
            "mlp_gate": {"pattern": "transformer.h.{layer}.mlp.c_proj"},
        },
    }

    def __init__(self):
        """Initialize AIR format handler."""
        # Attention roles come first, so they take precedence when matching.
        self.role_mappings = {**self.ATTENTION_ROLES, **self.MLP_ROLES}

    def export_to_air(
        self,
        lora_weights: Dict[str, torch.Tensor],
        metadata: "AIRMetadata",
        output_path: str
    ) -> None:
        """
        Export LoRA weights to AIR format.

        Writes two files: ``<output_path>.weights.safetensors`` with the
        role-keyed tensors and ``<output_path>.metadata.json`` with the
        metadata. Parameters whose name matches no known role are logged and
        left out. The parent directory of ``output_path`` is created if needed.

        If several parameters resolve to the same ``layer_<index>.<role>`` key
        (for instance the A and B matrices of one module), the last one wins;
        a warning is logged for every such overwrite.

        Args:
            lora_weights: Dictionary of LoRA weights with model-specific names
            metadata: Metadata about the adapter
            output_path: Path prefix of the AIR files to write
        """
        air_weights: Dict[str, torch.Tensor] = {}
        unmapped_modules: List[str] = []

        # Convert model-specific names to semantic roles
        for param_name, weight in lora_weights.items():
            role = self._get_semantic_role(param_name)
            if not role:
                unmapped_modules.append(param_name)
                logger.warning(f"Could not map parameter: {param_name}")
                continue

            layer_idx = self._extract_layer_index(param_name)
            air_key = f"layer_{layer_idx}.{role}"
            if air_key in air_weights:
                # Make the data loss visible instead of overwriting silently.
                logger.warning(
                    f"AIR key collision: {param_name} overwrites an earlier "
                    f"weight stored under {air_key}"
                )
            air_weights[air_key] = weight
            logger.debug(f"Mapped {param_name} -> {air_key}")

        if unmapped_modules:
            logger.info(f"Unmapped modules: {unmapped_modules}")

        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Use safetensors for efficient storage
        safetensors.torch.save_file(
            air_weights,
            f"{output_path}.weights.safetensors",
            metadata={"format": "air", "version": metadata.version}
        )

        # Save metadata separately
        with open(f"{output_path}.metadata.json", "w", encoding="utf-8") as f:
            json.dump(asdict(metadata), f, indent=2)

        logger.info(f"Exported adapter to AIR format: {output_path}")

    def import_from_air(
        self,
        air_path: str,
        target_model_info: Dict[str, Any]
    ) -> Tuple[Dict[str, torch.Tensor], "AIRMetadata"]:
        """
        Import AIR format adapter for a target model.

        Reads ``<air_path>.weights.safetensors`` and
        ``<air_path>.metadata.json``, then renames each AIR key to the naming
        pattern of ``target_model_info["architecture"]``. Keys that cannot be
        mapped (unknown architecture, role not present for that architecture,
        malformed key) are skipped and reported in a warning. When several AIR
        keys map to the same target name (e.g. a fused projection), the last
        one wins and a warning is logged.

        Args:
            air_path: Path prefix of the AIR format files
            target_model_info: Information about the target model architecture;
                must contain the key ``"architecture"``

        Returns:
            Tuple of (converted weights dict, metadata)

        Raises:
            KeyError: If ``target_model_info`` has no ``"architecture"`` entry.
        """
        # Load weights
        air_weights = safetensors.torch.load_file(f"{air_path}.weights.safetensors")

        # Load metadata
        with open(f"{air_path}.metadata.json", "r", encoding="utf-8") as f:
            metadata_dict = json.load(f)
        metadata = AIRMetadata(**metadata_dict)

        architecture = target_model_info["architecture"]

        # Convert to target model naming
        target_weights: Dict[str, torch.Tensor] = {}
        skipped_keys: List[str] = []
        for air_key, weight in air_weights.items():
            target_name = self._air_to_model_specific(air_key, architecture)
            if not target_name:
                skipped_keys.append(air_key)
                continue
            if target_name in target_weights:
                logger.warning(
                    f"Target name collision: {air_key} overwrites an earlier "
                    f"weight stored under {target_name}"
                )
            target_weights[target_name] = weight
            logger.debug(f"Mapped {air_key} -> {target_name}")

        if skipped_keys:
            logger.warning(
                f"No mapping to architecture {architecture!r} for AIR keys: "
                f"{skipped_keys}"
            )

        logger.info(f"Imported {len(target_weights)} weights from AIR format")
        return target_weights, metadata

    def _get_semantic_role(self, param_name: str) -> Optional[str]:
        """Map a model-specific parameter name to a semantic role.

        Returns the first role (in ``role_mappings`` order) that has a pattern
        occurring as a case-insensitive substring of ``param_name``, or
        ``None`` when nothing matches.
        """
        param_lower = param_name.lower()
        for role, patterns in self.role_mappings.items():
            # Patterns such as "Wqkv.q" contain upper-case letters, so both
            # sides must be lowered for the comparison to ever succeed.
            if any(pattern.lower() in param_lower for pattern in patterns):
                return role
        return None

    def _extract_layer_index(self, param_name: str) -> int:
        """Extract layer index from parameter name.

        Recognises ``layer.N`` / ``layers.N``, ``block.N`` / ``blocks.N`` and a
        standalone ``h.N`` component. When several occur, the leftmost
        (outermost) one is used. Returns 0 if no index is found.
        """
        import re

        # "\b" keeps "h" from matching the tail of a longer token.
        match = re.search(r"(?:layers?|blocks?|\bh)\.(\d+)", param_name)
        if match:
            return int(match.group(1))
        return 0  # Default to layer 0 if no index found

    def _air_to_model_specific(
        self,
        air_key: str,
        target_architecture: str
    ) -> Optional[str]:
        """Convert AIR role to target model-specific naming.

        Returns ``None`` when the key is not of the form
        ``<prefix>_<int>.<role>``, when the architecture is unknown, or when
        the architecture defines no pattern for the role.
        """
        # Parse AIR key
        parts = air_key.split(".")
        if len(parts) != 2:
            return None
        layer_part, role = parts
        try:
            layer_idx = int(layer_part.split("_")[1])
        except (IndexError, ValueError):
            return None

        # Get architecture-specific naming
        binder = self._get_architecture_binder(target_architecture)
        pattern = binder.get(role, {}).get("pattern")
        if not pattern:
            return None
        return pattern.format(layer=layer_idx)

    def _get_architecture_binder(self, architecture: str) -> Dict[str, Dict]:
        """Get the naming patterns for a specific architecture.

        The lookup is case-insensitive. Returns an empty dict for an unknown
        architecture.
        """
        return self._ARCHITECTURE_BINDERS.get(architecture.lower(), {})