Source code for ual_adapter.utils.model_utils

"""
Model Analysis Utilities

Tools for analyzing model architectures and detecting LoRA-compatible modules.
"""

from typing import Dict, List, Optional, Any, Set, Tuple
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from loguru import logger


[docs] class ModelAnalyzer: """ Analyzes model architectures to extract useful information for UAL. """
[docs] def __init__(self, model: nn.Module): """ Initialize model analyzer. Args: model: The model to analyze """ self.model = model self._architecture = None self._module_map = None
[docs] def analyze(self) -> Dict[str, Any]: """ Perform comprehensive model analysis. Returns: Dictionary with model information """ info = { "architecture": self._detect_architecture(), "num_parameters": sum(p.numel() for p in self.model.parameters()), "num_layers": self._count_layers(), "hidden_size": self._detect_hidden_size(), "module_types": self._get_module_types(), "attention_modules": self._find_attention_modules(), "mlp_modules": self._find_mlp_modules(), } # Add model name if available if hasattr(self.model, "name_or_path"): info["model_name"] = self.model.name_or_path elif hasattr(self.model, "config") and hasattr(self.model.config, "_name_or_path"): info["model_name"] = self.model.config._name_or_path return info
[docs] def get_lora_target_modules(self) -> List[str]: """ Auto-detect the best modules for LoRA application. Returns: List of module names/patterns for LoRA targets """ model_type = self._detect_architecture().lower() # Get all module names all_modules = [name for name, _ in self.model.named_modules()] # Architecture-specific patterns architecture_patterns = { "gpt2": ["c_attn", "c_proj", "c_fc"], "gptj": ["q_proj", "v_proj"], "gpt_neox": ["query_key_value", "dense"], "llama": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], "opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], "bloom": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], "bert": ["query", "key", "value", "dense"], "roberta": ["query", "key", "value", "dense"], "t5": ["q", "k", "v", "o", "wi", "wo"], "phi": ["Wqkv", "out_proj", "fc1", "fc2"], "mistral": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], "qwen": ["c_attn", "c_proj", "w1", "w2"], } # Try to find architecture-specific patterns target_modules = [] for arch, patterns in architecture_patterns.items(): if arch in model_type: for pattern in patterns: # Check if pattern exists in model for module_name in all_modules: if pattern in module_name and '.' not in module_name.replace(pattern, ''): if pattern not in target_modules: target_modules.append(pattern) if target_modules: logger.debug(f"Found architecture-specific patterns for {arch}") break # If no specific patterns found, use generic attention module detection if not target_modules: target_modules = self._generic_lora_detection() # Final fallback if not target_modules: logger.warning("Could not detect specific modules, using Linear layers") target_modules = ["Linear"] return list(set(target_modules)) # Remove duplicates
def _generic_lora_detection(self) -> List[str]: """Generic detection of LoRA-compatible modules.""" target_modules = [] for name, module in self.model.named_modules(): # Look for linear layers in attention/mlp blocks if isinstance(module, nn.Linear): name_lower = name.lower() # Check for attention-related patterns attention_keywords = ["attention", "attn", "query", "key", "value", "proj"] mlp_keywords = ["mlp", "ffn", "feedforward", "fc", "dense"] is_attention = any(keyword in name_lower for keyword in attention_keywords) is_mlp = any(keyword in name_lower for keyword in mlp_keywords) if is_attention or is_mlp: # Extract the final component final_component = name.split('.')[-1] if final_component not in target_modules: target_modules.append(final_component) return target_modules def _detect_architecture(self) -> str: """Detect the model architecture.""" if self._architecture: return self._architecture # Check model class name model_class = self.model.__class__.__name__.lower() # Common architecture mappings architecture_map = { "gpt2": "gpt2", "gptj": "gptj", "gptneox": "gpt_neox", "llama": "llama", "opt": "opt", "bloom": "bloom", "bert": "bert", "roberta": "roberta", "t5": "t5", "phi": "phi", "mistral": "mistral", "qwen": "qwen", "pythia": "pythia", "falcon": "falcon", "mpt": "mpt", } for key, arch in architecture_map.items(): if key in model_class: self._architecture = arch return arch # Check config if available if hasattr(self.model, "config"): config = self.model.config if hasattr(config, "model_type"): self._architecture = config.model_type return config.model_type # Default self._architecture = "unknown" return "unknown" def _detect_hidden_size(self) -> int: """Detect the model's hidden size.""" # Try config first if hasattr(self.model, "config"): config = self.model.config # Common attribute names for hidden size size_attrs = ["hidden_size", "d_model", "n_embd", "dim"] for attr in size_attrs: if hasattr(config, attr): return getattr(config, attr) # Try to infer from linear layers for module in self.model.modules(): if isinstance(module, nn.Linear): # Assume first dimension is often hidden size return module.weight.shape[0] return 768 # Default fallback def _count_layers(self) -> int: """Count the number of transformer layers.""" # Try config first if hasattr(self.model, "config"): config = self.model.config # Common attribute names for layer count layer_attrs = ["num_hidden_layers", "n_layers", "n_layer", "num_layers"] for attr in layer_attrs: if hasattr(config, attr): return getattr(config, attr) # Count by module patterns layer_count = 0 for name in self.model.state_dict().keys(): # Look for layer indices import re patterns = [r"layer\.(\d+)", r"layers\.(\d+)", r"h\.(\d+)", r"block\.(\d+)"] for pattern in patterns: match = re.search(pattern, name) if match: layer_idx = int(match.group(1)) layer_count = max(layer_count, layer_idx + 1) return layer_count if layer_count > 0 else 12 # Default fallback def _get_module_types(self) -> Dict[str, int]: """Get counts of different module types.""" module_types = {} for module in self.model.modules(): module_type = module.__class__.__name__ module_types[module_type] = module_types.get(module_type, 0) + 1 return module_types def _find_attention_modules(self) -> List[str]: """Find attention-related modules.""" attention_modules = [] for name, module in self.model.named_modules(): name_lower = name.lower() if any(keyword in name_lower for keyword in ["attention", "attn", "self_attn"]): if isinstance(module, nn.Linear): attention_modules.append(name) return attention_modules def _find_mlp_modules(self) -> List[str]: """Find MLP/FFN modules.""" mlp_modules = [] for name, module in self.model.named_modules(): name_lower = name.lower() if any(keyword in name_lower for keyword in ["mlp", "ffn", "feedforward", "fc"]): if isinstance(module, nn.Linear): mlp_modules.append(name) return mlp_modules
[docs] def get_module_dimensions(self, module_name: str) -> Optional[Tuple[int, int]]: """ Get input and output dimensions for a specific module. Args: module_name: Name of the module Returns: Tuple of (in_features, out_features) or None """ try: module = self.model for part in module_name.split('.'): module = getattr(module, part) if isinstance(module, nn.Linear): return (module.in_features, module.out_features) elif hasattr(module, "weight"): weight_shape = module.weight.shape if len(weight_shape) == 2: return (weight_shape[1], weight_shape[0]) except AttributeError: pass return None
[docs] def validate_lora_compatibility( self, target_modules: List[str] ) -> Dict[str, bool]: """ Validate if target modules are compatible with LoRA. Args: target_modules: List of module names/patterns to check Returns: Dictionary mapping module patterns to compatibility status """ compatibility = {} for pattern in target_modules: found = False for name, module in self.model.named_modules(): if pattern in name: # Check if module is compatible (Linear layer) if isinstance(module, nn.Linear): found = True break compatibility[pattern] = found return compatibility