# Source code for ual_adapter.core.dispatcher

"""
LoRA Dispatcher with Intelligent Routing

Automatically selects and applies the most suitable LoRA adapter
based on query content using sentence embeddings and classification.
"""

import json
import pickle
import weakref
from typing import Dict, List, Optional, Tuple, Any

import numpy as np
import torch
from loguru import logger
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression


class LoRADispatcher:
    """
    Intelligent dispatcher that routes queries to appropriate domain LoRAs.

    Uses sentence embeddings and a trained classifier to determine
    which domain-specific adapter to use for a given query.

    Router class labels are the positions of the domains in ``self.domains``
    (insertion order), so every probability column is mapped back to a domain
    name through ``router.classes_``.
    """

    def __init__(
        self,
        encoder_model: str = "all-MiniLM-L6-v2",
        confidence_threshold: float = 0.7,
        cache_embeddings: bool = True
    ):
        """
        Initialize the LoRA dispatcher.

        Args:
            encoder_model: Sentence transformer model for embeddings
            confidence_threshold: Minimum confidence for domain selection
            cache_embeddings: Whether to cache computed embeddings
        """
        self.encoder = SentenceTransformer(encoder_model)
        self.encoder_model_name = encoder_model  # Store for later saving
        self.confidence_threshold = confidence_threshold
        self.cache_embeddings = cache_embeddings

        # Domain registry
        self.domains: Dict[str, "DomainAdapter"] = {}
        self.router = None
        self.domain_embeddings = {}
        self.embedding_cache = {} if cache_embeddings else None

        # Per-model record of the (parameter, delta) pairs currently added by
        # _apply_weights, so the previous adapter can be undone before the
        # next one is applied. Weak keys: the record never keeps a model alive.
        self._weight_backups = weakref.WeakKeyDictionary()

        logger.info(f"Initialized LoRA Dispatcher with encoder: {encoder_model}")

    def register_domain(
        self,
        domain_name: str,
        adapter_weights: Dict[str, torch.Tensor],
        training_texts: List[str],
        metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Register a domain-specific LoRA adapter.

        Embeds ``training_texts`` with the sentence encoder, stores the
        adapter and its embeddings, and retrains the router once more than
        one domain is registered.

        Args:
            domain_name: Name of the domain
            adapter_weights: LoRA weights for this domain
            training_texts: Sample texts from this domain for routing
            metadata: Optional metadata about the domain
        """
        if len(self.domains) >= 10:
            logger.warning("Maximum 10 domains recommended for optimal performance")

        if not training_texts:
            logger.warning(
                f"Domain '{domain_name}' has no training texts; "
                f"the router cannot select it"
            )

        # Create domain adapter
        domain_adapter = DomainAdapter(
            name=domain_name,
            weights=adapter_weights,
            metadata=metadata or {}
        )

        # Compute and store embeddings for training texts
        logger.info(f"Computing embeddings for domain '{domain_name}'...")
        embeddings = self.encoder.encode(
            training_texts,
            show_progress_bar=True,
            batch_size=32
        )

        self.domains[domain_name] = domain_adapter
        self.domain_embeddings[domain_name] = embeddings

        logger.info(
            f"Registered domain '{domain_name}' with "
            f"{len(training_texts)} training samples"
        )

        # Retrain router with new domain
        if len(self.domains) > 1:
            self._train_router()

    def _domain_samples(self, domain_name: str):
        """Return the stored embeddings for a domain, or None if it has none."""
        embeddings = self.domain_embeddings.get(domain_name)
        if embeddings is None or len(embeddings) == 0:
            return None
        return embeddings

    def _train_router(self) -> None:
        """
        Train the router classifier on domain embeddings.

        Each sample is labelled with the position of its domain in
        ``self.domains``. Domains without embeddings are skipped; if fewer
        than two domains have samples the router is cleared instead of
        fitted, because the classifier needs at least two classes.
        """
        logger.info("Training router classifier...")

        X = []
        y = []

        # Prepare training data
        for domain_id, domain_name in enumerate(self.domains):
            embeddings = self._domain_samples(domain_name)
            if embeddings is None:
                continue
            X.extend(embeddings)
            y.extend([domain_id] * len(embeddings))

        if len(set(y)) < 2:
            logger.warning(
                "Router needs training samples from at least 2 domains; "
                "routing disabled"
            )
            self.router = None
            return

        # Train logistic regression classifier
        # NOTE(review): recent scikit-learn releases deprecate `multi_class`;
        # confirm against the pinned version before upgrading.
        self.router = LogisticRegression(
            max_iter=1000,
            multi_class="multinomial",
            solver="lbfgs",
            random_state=42
        )
        self.router.fit(X, y)

        # Evaluate on training data
        accuracy = self.router.score(X, y)
        logger.info(f"Router accuracy on training data: {accuracy:.2%}")

    def route_query(
        self,
        query: str,
        return_all_scores: bool = False
    ) -> Tuple[Optional[str], float, Optional[Dict[str, float]]]:
        """
        Determine which domain adapter to use for a query.

        Args:
            query: The input query text
            return_all_scores: Whether to return scores for all domains

        Returns:
            Tuple of (selected_domain, confidence, all_scores_dict).
            ``selected_domain`` is None when no domain is registered or the
            best confidence is below ``confidence_threshold``. Without a
            trained router the first registered domain is returned with
            confidence 1.0. ``all_scores_dict`` is None unless
            ``return_all_scores`` is set and a router was consulted.
        """
        # Get embedding for query (always compute/cache for consistency)
        if self.cache_embeddings and query in self.embedding_cache:
            query_embedding = self.embedding_cache[query]
        else:
            query_embedding = self.encoder.encode([query])[0]
            if self.cache_embeddings:
                self.embedding_cache[query] = query_embedding

        if self.router is None or len(self.domains) < 2:
            # Use first/only domain if no routing needed
            if self.domains:
                return next(iter(self.domains)), 1.0, None
            return None, 0.0, None

        # Get probabilities from router
        probabilities = self.router.predict_proba([query_embedding])[0]
        domain_names = list(self.domains)

        # Column i of predict_proba belongs to label classes_[i], which is a
        # position in domain_names; it is not necessarily i itself when a
        # domain contributed no training samples.
        class_ids = [int(label) for label in self.router.classes_]

        # Find best domain
        best_pos = int(np.argmax(probabilities))
        best_confidence = float(probabilities[best_pos])
        best_domain = None
        if best_confidence >= self.confidence_threshold:
            best_domain = domain_names[class_ids[best_pos]]

        # Prepare all scores if requested
        all_scores = None
        if return_all_scores:
            # Domains the router was not trained on score 0.0.
            all_scores = {name: 0.0 for name in domain_names}
            for label, prob in zip(class_ids, probabilities):
                all_scores[domain_names[label]] = float(prob)

        return best_domain, best_confidence, all_scores

    def apply_adapter(
        self,
        model: torch.nn.Module,
        query: str,
        verbose: bool = True
    ) -> Tuple[torch.nn.Module, Optional[str], float]:
        """
        Apply the appropriate adapter to a model based on the query.

        The model is modified in place. Any adapter this dispatcher applied
        to the same model earlier is removed first, so repeated calls do not
        stack deltas. When no domain is selected the model is returned with
        no adapter applied.

        Args:
            model: The base model to apply adapter to
            query: The input query
            verbose: Whether to log routing decisions

        Returns:
            Tuple of (model_with_adapter, selected_domain, confidence);
            ``selected_domain`` is None when no domain was selected.
        """
        # Route query to domain
        domain, confidence, _ = self.route_query(query)

        if domain is None:
            if verbose:
                logger.info(
                    f"⚠️ No suitable domain found "
                    f"(best confidence: {confidence:.2%} < threshold)"
                )
            # Do not leave a previously selected domain's delta in place
            # while reporting that no domain was selected.
            self._restore_weights(model)
            return model, None, confidence

        if verbose:
            logger.info(
                f"🎯 Routing to domain '{domain}' "
                f"(confidence: {confidence:.2%})"
            )

        # Apply domain adapter weights
        adapter = self.domains[domain]
        model = self._apply_weights(model, adapter.weights)
        return model, domain, confidence

    def _restore_weights(self, model: torch.nn.Module) -> None:
        """Subtract the deltas this dispatcher last added to ``model``, if any."""
        applied = self._weight_backups.pop(model, None)
        if not applied:
            return
        for param, delta in applied:
            param.data -= delta

    def _apply_weights(
        self,
        model: torch.nn.Module,
        adapter_weights: Dict[str, torch.Tensor]
    ) -> torch.nn.Module:
        """
        Apply LoRA weights to model.

        Each key of ``adapter_weights`` is a dotted attribute path to a
        module of ``model``; the tensor is added to that module's ``weight``.
        Deltas from a previous call on the same model are removed first.

        Args:
            model: Model to modify in place
            adapter_weights: Mapping of module path to weight delta

        Returns:
            The same ``model`` object.
        """
        # This is a simplified version - in production would use PEFT
        self._restore_weights(model)

        # Registered before mutating anything so that a failure part-way
        # through still leaves a record of what was already added.
        applied = self._weight_backups.setdefault(model, [])

        for name, weight in adapter_weights.items():
            # Find corresponding module in model
            try:
                module = model
                for part in name.split('.'):
                    module = getattr(module, part)
            except AttributeError:
                logger.debug(f"Could not apply weight to {name}")
                continue

            target = getattr(module, 'weight', None)
            if target is None:
                logger.debug(f"Could not apply weight to {name}")
                continue

            # Match the parameter's device/dtype so adapters stored elsewhere
            # (e.g. on CPU) can still be added.
            delta = torch.as_tensor(
                weight, device=target.device, dtype=target.dtype
            )

            # Apply LoRA weight update
            target.data += delta
            applied.append((target, delta))

        return model

    def analyze_domain_overlap(self) -> Dict[str, Any]:
        """
        Analyze overlap between registered domains.

        Returns:
            With fewer than two domains, a dict holding only a ``message``.
            Otherwise a dict with ``domain_count``, ``pairwise_similarities``
            (cosine similarity of the mean embeddings, keyed
            ``"<domain1>-<domain2>"``) and ``domain_separability`` (mean
            router probability each domain's own samples receive for that
            domain; empty when no router is trained). Domains without stored
            embeddings are left out of both maps.
        """
        if len(self.domains) < 2:
            return {"message": "Need at least 2 domains for overlap analysis"}

        analysis = {
            "domain_count": len(self.domains),
            "pairwise_similarities": {},
            "domain_separability": {}
        }

        domain_names = list(self.domains)

        # Mean embedding per domain, computed once instead of once per pair
        mean_embeddings = {}
        for domain_name in domain_names:
            embeddings = self._domain_samples(domain_name)
            if embeddings is not None:
                mean_embeddings[domain_name] = np.mean(embeddings, axis=0)

        # Compute pairwise cosine similarities
        for i, domain1 in enumerate(domain_names):
            for domain2 in domain_names[i + 1:]:
                if domain1 not in mean_embeddings or domain2 not in mean_embeddings:
                    continue
                mean1 = mean_embeddings[domain1]
                mean2 = mean_embeddings[domain2]

                # Cosine similarity
                similarity = np.dot(mean1, mean2) / (
                    np.linalg.norm(mean1) * np.linalg.norm(mean2)
                )

                pair_key = f"{domain1}-{domain2}"
                analysis["pairwise_similarities"][pair_key] = float(similarity)

        # Compute domain separability (classification confidence)
        if self.router is not None:
            # Probability column for each class label known to the router
            column_of = {
                int(label): column
                for column, label in enumerate(self.router.classes_)
            }
            for domain_id, domain_name in enumerate(domain_names):
                embeddings = self._domain_samples(domain_name)
                column = column_of.get(domain_id)
                if embeddings is None or column is None:
                    continue
                predictions = self.router.predict_proba(embeddings)

                # Average confidence for correct domain
                correct_confidence = np.mean(predictions[:, column])
                analysis["domain_separability"][domain_name] = float(correct_confidence)

        return analysis

    def save(self, path: str) -> None:
        """
        Save dispatcher state to disk.

        Writes ``{path}_dispatcher.json`` (metadata, embeddings, settings),
        ``{path}_router.pkl`` (only when a router is trained) and one
        ``{path}_weights_{domain}.pt`` per domain.

        Args:
            path: Prefix shared by all files written
        """
        save_data = {
            "domains": {
                name: {
                    "metadata": domain.metadata,
                    "weights_shape": {
                        k: list(v.shape) for k, v in domain.weights.items()
                    }
                }
                for name, domain in self.domains.items()
            },
            # Arrays are converted to nested lists explicitly: letting the
            # JSON `default=str` fallback handle them would store a truncated
            # string repr that cannot be loaded back.
            "domain_embeddings": {
                name: np.asarray(embeddings).tolist()
                for name, embeddings in self.domain_embeddings.items()
            },
            "confidence_threshold": self.confidence_threshold,
            "encoder_model": self.encoder_model_name
        }

        # Save router separately if trained
        if self.router is not None:
            with open(f"{path}_router.pkl", "wb") as f:
                pickle.dump(self.router, f)

        # Save main data (default=str only covers non-JSON metadata values)
        with open(f"{path}_dispatcher.json", "w", encoding="utf-8") as f:
            json.dump(save_data, f, indent=2, default=str)

        # Save adapter weights
        for domain_name, domain in self.domains.items():
            torch.save(
                domain.weights,
                f"{path}_weights_{domain_name}.pt"
            )

        logger.info(f"Saved dispatcher to {path}")

    @classmethod
    def load(cls, path: str) -> "LoRADispatcher":
        """
        Load dispatcher state from disk.

        Args:
            path: Prefix that was passed to ``save``

        Returns:
            A new dispatcher with domains, embeddings and router restored.
            If the router file is missing and more than one domain was
            saved, the router is retrained from the restored embeddings.
        """
        # Load main data
        with open(f"{path}_dispatcher.json", "r", encoding="utf-8") as f:
            save_data = json.load(f)

        # Create dispatcher
        dispatcher = cls(
            encoder_model=save_data.get("encoder_model", "all-MiniLM-L6-v2"),
            confidence_threshold=save_data["confidence_threshold"]
        )

        # Load domains
        for domain_name, domain_data in save_data["domains"].items():
            weights = torch.load(f"{path}_weights_{domain_name}.pt")
            dispatcher.domains[domain_name] = DomainAdapter(
                name=domain_name,
                weights=weights,
                metadata=domain_data["metadata"]
            )

        # Load embeddings
        for domain_name, values in save_data["domain_embeddings"].items():
            if isinstance(values, str):
                # Files written before embeddings were stored as lists hold a
                # string repr of the array, which cannot be parsed back.
                logger.warning(
                    f"Embeddings for domain '{domain_name}' were saved in an "
                    f"unreadable legacy format; skipping them"
                )
                continue
            dispatcher.domain_embeddings[domain_name] = np.array(values)

        # Load router if exists
        try:
            with open(f"{path}_router.pkl", "rb") as f:
                # NOTE(security): unpickling can execute arbitrary code; only
                # load dispatcher files from a trusted source.
                dispatcher.router = pickle.load(f)
        except FileNotFoundError:
            if len(dispatcher.domains) > 1:
                logger.warning("No router found, retraining from saved embeddings")
                dispatcher._train_router()

        logger.info(f"Loaded dispatcher from {path}")
        return dispatcher


class DomainAdapter:
    """Container for domain-specific adapter information."""

    def __init__(
        self,
        name: str,
        weights: Dict[str, torch.Tensor],
        metadata: Dict[str, Any]
    ):
        """
        Initialize domain adapter.

        Args:
            name: Domain name
            weights: LoRA weights for this domain
            metadata: Additional information about the domain
        """
        self.name = name
        self.weights = weights
        self.metadata = metadata

    def __repr__(self) -> str:
        """Return a short description: domain name and number of weight tensors."""
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"num_tensors={len(self.weights)})"
        )

    def get_info(self) -> Dict[str, Any]:
        """
        Get information about this domain adapter.

        Returns:
            Dict with the domain ``name``, the total element count across all
            weight tensors (``num_parameters``), each tensor's shape as a
            list (``weight_shapes``) and the stored ``metadata``.
        """
        return {
            "name": self.name,
            "num_parameters": sum(w.numel() for w in self.weights.values()),
            "weight_shapes": {k: list(v.shape) for k, v in self.weights.items()},
            "metadata": self.metadata
        }