Source code for ual_adapter.core.projection

"""
Dimension-Adaptive Projection System

Handles dimension mismatches when transferring adapters between models
with different hidden dimensions using SVD-based projection.
"""

from typing import Tuple, Optional, Dict
import numpy as np
import torch
from loguru import logger


class DimensionProjector:
    """
    Handles dimension adaptation for LoRA weights using SVD projection.

    This allows adapters trained on models with one dimension (e.g., 768)
    to be transferred to models with different dimensions (e.g., 2048).

    A LoRA adapter is a pair ``(A, B)`` with ``A`` of shape
    ``(rank, in_features)`` and ``B`` of shape ``(out_features, rank)``;
    the weight update it represents is ``B @ A``.
    """

    # Number of leading singular directions compared by the
    # "subspace_alignment" metric in analyze_projection_quality.
    _ALIGNMENT_TOP_K = 5

    def __init__(self, variance_threshold: float = 0.95):
        """
        Initialize the dimension projector.

        Args:
            variance_threshold: Percentage of variance to preserve (0-1).
                The value is stored on the instance and reported in the
                SVD debug log; the SVD projection itself always keeps the
                original LoRA rank.
        """
        self.variance_threshold = variance_threshold

    def project_adapter(
        self,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        target_in_dim: int,
        target_out_dim: int,
        method: str = "svd"
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Project LoRA adapter to target dimensions.

        Args:
            lora_a: LoRA A matrix (rank x in_features)
            lora_b: LoRA B matrix (out_features x rank)
            target_in_dim: Target input dimension
            target_out_dim: Target output dimension
            method: Projection method ("svd", "truncate", "interpolate")

        Returns:
            Tuple of projected (A, B) matrices

        Raises:
            ValueError: If ``method`` is unknown, if either matrix is not
                2-D, if the ranks of ``lora_a`` and ``lora_b`` disagree,
                or if a target dimension is not positive.
        """
        strategies = {
            "svd": self._svd_projection,
            "truncate": self._truncate_projection,
            "interpolate": self._interpolate_projection,
        }
        strategy = strategies.get(method)
        if strategy is None:
            raise ValueError(f"Unknown projection method: {method}")

        self._validate_inputs(lora_a, lora_b, target_in_dim, target_out_dim)
        return strategy(lora_a, lora_b, target_in_dim, target_out_dim)

    @staticmethod
    def _validate_inputs(
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        target_in_dim: int,
        target_out_dim: int
    ) -> None:
        """
        Reject inputs that cannot form a valid LoRA pair or target shape.
        """
        if lora_a.dim() != 2 or lora_b.dim() != 2:
            raise ValueError(
                "LoRA matrices must be 2-D, got shapes "
                f"{tuple(lora_a.shape)} and {tuple(lora_b.shape)}"
            )
        if lora_b.shape[1] != lora_a.shape[0]:
            raise ValueError(
                "LoRA rank mismatch: lora_a has rank "
                f"{lora_a.shape[0]} but lora_b has rank {lora_b.shape[1]}"
            )
        if target_in_dim < 1 or target_out_dim < 1:
            raise ValueError(
                "Target dimensions must be positive, got "
                f"in={target_in_dim}, out={target_out_dim}"
            )

    def _svd_projection(
        self,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        target_in_dim: int,
        target_out_dim: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        SVD-based projection preserving the most important subspace.

        The full update ``B @ A`` is decomposed with an SVD, truncated to
        the original LoRA rank, resized to the target shape by truncation
        or zero-padding of the singular vectors, rescaled, and split back
        into two factors.

        Returns:
            Tuple of projected (A, B) matrices, both on ``lora_a``'s
            device and in ``lora_a``'s dtype.
        """
        # Detach so trainable parameters can be converted to NumPy, and
        # promote to at least float32 because NumPy's SVD does not accept
        # half-precision input.
        compute_dtype = torch.promote_types(
            torch.promote_types(lora_a.dtype, lora_b.dtype), torch.float32
        )
        delta = torch.matmul(
            lora_b.detach().to(compute_dtype),
            lora_a.detach().to(compute_dtype),
        ).cpu().numpy()

        # Perform SVD decomposition
        U, S, Vt = np.linalg.svd(delta, full_matrices=False)

        # Always preserve the original LoRA rank to maintain adapter capacity
        original_rank = lora_a.shape[0]
        preserved_rank = min(original_rank, len(S))

        # Truncate to preserved rank
        U_trunc = U[:, :preserved_rank]
        S_trunc = S[:preserved_rank]
        Vt_trunc = Vt[:preserved_rank, :]

        # Report the variance actually retained by the truncation rather
        # than echoing the configured threshold as if it had been applied.
        total_energy = float(np.sum(S ** 2))
        if total_energy > 0:
            retained = float(np.sum(S_trunc ** 2)) / total_energy
        else:
            retained = 1.0
        logger.debug(
            f"Preserving rank {preserved_rank}/{len(S)} "
            f"({retained * 100:.1f}% variance retained, "
            f"threshold {self.variance_threshold * 100:.1f}%)"
        )

        # Adapt to target dimensions
        U_adapted = self._resize_matrix(
            U_trunc, (target_out_dim, preserved_rank)
        )
        Vt_adapted = self._resize_matrix(
            Vt_trunc, (preserved_rank, target_in_dim)
        )

        # Apply adaptive scaling
        scale_factor = self._compute_scale_factor(
            delta.shape, (target_out_dim, target_in_dim)
        )
        S_scaled = S_trunc * scale_factor

        # Reconstruct as LoRA factors: split each singular value evenly
        # between the two factors so that B_new @ A_new = U * S * Vt.
        sqrt_S = np.sqrt(S_scaled)
        lora_b_new = U_adapted * sqrt_S[np.newaxis, :]
        lora_a_new = sqrt_S[:, np.newaxis] * Vt_adapted

        # Convert back to torch tensors
        device = lora_a.device
        dtype = lora_a.dtype
        return (
            torch.from_numpy(lora_a_new).to(device=device, dtype=dtype),
            torch.from_numpy(lora_b_new).to(device=device, dtype=dtype)
        )

    def _truncate_projection(
        self,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        target_in_dim: int,
        target_out_dim: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Simple truncation or zero-padding projection.

        No scaling is applied, so the original LoRA weights are kept
        exactly in the overlapping region. The returned tensors are new
        tensors on the same device as their inputs.
        """
        rank = lora_a.shape[0]
        lora_a_new = self._resize_tensor(lora_a, (rank, target_in_dim))
        lora_b_new = self._resize_tensor(lora_b, (target_out_dim, rank))
        return lora_a_new, lora_b_new

    @staticmethod
    def _resize_tensor(
        tensor: torch.Tensor,
        target_shape: Tuple[int, int]
    ) -> torch.Tensor:
        """
        Copy a 2-D tensor into a zero tensor of the target shape.
        """
        # new_zeros inherits dtype and device, so GPU inputs stay on the
        # GPU; the result never shares storage with the input.
        resized = tensor.new_zeros(target_shape)
        rows = min(tensor.shape[0], target_shape[0])
        cols = min(tensor.shape[1], target_shape[1])
        resized[:rows, :cols] = tensor[:rows, :cols]
        return resized

    def _interpolate_projection(
        self,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        target_in_dim: int,
        target_out_dim: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Interpolation-based projection for smooth resizing.

        Each matrix is treated as a single-channel image and resized with
        bilinear interpolation; the rank axis keeps its size.
        """
        import torch.nn.functional as F

        rank = lora_a.shape[0]

        # Interpolate A matrix (1 x 1 x R x D layout for F.interpolate)
        lora_a_new = F.interpolate(
            lora_a.unsqueeze(0).unsqueeze(0),
            size=(rank, target_in_dim),
            mode='bilinear',
            align_corners=False
        )[0, 0]

        # Interpolate B matrix (1 x 1 x D x R layout for F.interpolate)
        lora_b_new = F.interpolate(
            lora_b.unsqueeze(0).unsqueeze(0),
            size=(target_out_dim, rank),
            mode='bilinear',
            align_corners=False
        )[0, 0]

        # Indexing [0, 0] removes only the two axes added above; a bare
        # squeeze() would also drop a rank (or target) axis of size 1.
        return lora_a_new, lora_b_new

    def _resize_matrix(
        self,
        matrix: np.ndarray,
        target_shape: Tuple[int, int]
    ) -> np.ndarray:
        """
        Resize a matrix to target shape by truncation or padding.

        Entries outside the overlapping region are zero.
        """
        current_shape = matrix.shape
        resized = np.zeros(target_shape)

        # Copy the overlapping region
        min_rows = min(current_shape[0], target_shape[0])
        min_cols = min(current_shape[1], target_shape[1])
        resized[:min_rows, :min_cols] = matrix[:min_rows, :min_cols]

        return resized

    def _compute_scale_factor(
        self,
        source_shape: Tuple[int, int],
        target_shape: Tuple[int, int]
    ) -> float:
        """
        Compute adaptive scaling factor to maintain update magnitude.

        Returns:
            ``sqrt(source_dim / target_dim)`` where each ``*_dim`` is the
            geometric mean of the corresponding shape's two sides.
        """
        source_dim = np.sqrt(source_shape[0] * source_shape[1])
        target_dim = np.sqrt(target_shape[0] * target_shape[1])

        # Scale inversely with dimension ratio
        return float(np.sqrt(source_dim / target_dim))

    @staticmethod
    def _safe_ratio(
        numerator: float,
        denominator: float,
        zero_over_zero: float
    ) -> float:
        """
        Divide, returning ``zero_over_zero`` for 0/0 and ``inf`` for x/0.
        """
        if denominator != 0:
            return numerator / denominator
        return zero_over_zero if numerator == 0 else float("inf")

    def analyze_projection_quality(
        self,
        original_delta: torch.Tensor,
        projected_a: torch.Tensor,
        projected_b: torch.Tensor
    ) -> Dict[str, float]:
        """
        Analyze the quality of a projection.

        Args:
            original_delta: The original weight update (``B @ A``).
            projected_a: Projected LoRA A matrix.
            projected_b: Projected LoRA B matrix.

        Returns:
            Dictionary of metrics:

            - ``"reconstruction_error"``: mean squared error relative to
              the mean square of the original (only when shapes match).
            - ``"rank_preservation"``: reconstructed rank / original rank.
            - ``"norm_ratio"``: ratio of Frobenius norms.
            - ``"subspace_alignment"``: mean cosine of the principal
              angles between the leading left singular subspaces (only
              when shapes match).

            Ratios whose numerator and denominator are both zero (e.g. an
            all-zero adapter) are reported as a perfect match instead of
            raising ``ZeroDivisionError``; a non-zero value over zero is
            reported as ``inf``.
        """
        # Reconstruct delta from projected factors
        reconstructed = torch.matmul(projected_b, projected_a).detach()

        # The linear-algebra routines below need a common floating dtype
        # of at least single precision.
        work_dtype = torch.promote_types(
            torch.promote_types(original_delta.dtype, reconstructed.dtype),
            torch.float32,
        )
        reconstructed = reconstructed.to(work_dtype)
        original = original_delta.detach().to(
            device=reconstructed.device, dtype=work_dtype
        )
        same_shape = original.shape == reconstructed.shape

        metrics = {}

        # Reconstruction error (if dimensions match)
        if same_shape:
            mse = torch.mean((original - reconstructed) ** 2).item()
            metrics["reconstruction_error"] = self._safe_ratio(
                mse, torch.mean(original ** 2).item(), 0.0
            )

        # Rank analysis
        original_rank = int(torch.linalg.matrix_rank(original).item())
        reconstructed_rank = int(
            torch.linalg.matrix_rank(reconstructed).item()
        )
        metrics["rank_preservation"] = self._safe_ratio(
            reconstructed_rank, original_rank, 1.0
        )

        # Norm preservation
        original_norm = torch.norm(original, 'fro').item()
        reconstructed_norm = torch.norm(reconstructed, 'fro').item()
        metrics["norm_ratio"] = self._safe_ratio(
            reconstructed_norm, original_norm, 1.0
        )

        # Subspace alignment (principal angles)
        if same_shape:
            # Only compare directions that carry signal: singular vectors
            # beyond the numerical rank are arbitrary.
            top_k = min(
                self._ALIGNMENT_TOP_K, original_rank, reconstructed_rank
            )
            if top_k == 0:
                alignment = (
                    1.0 if original_rank == reconstructed_rank else 0.0
                )
            else:
                U_orig = torch.linalg.svd(original, full_matrices=False)[0]
                U_recon = torch.linalg.svd(
                    reconstructed, full_matrices=False
                )[0]
                overlap = torch.matmul(
                    U_orig[:, :top_k].T, U_recon[:, :top_k]
                )
                # The cosines of the principal angles are the singular
                # values of the overlap matrix; unlike its diagonal they
                # do not depend on the sign or order of singular vectors.
                cos_angles = torch.linalg.svd(
                    overlap, full_matrices=False
                )[1].clamp(max=1.0)
                alignment = torch.mean(cos_angles).item()
            metrics["subspace_alignment"] = alignment

        return metrics