"""
Binder Registry
Manages registration and retrieval of architecture-specific binders.
"""
from typing import Dict, Optional, Type
from loguru import logger
from ual_adapter.binders.base import ModelBinder
from ual_adapter.binders.architectures import (
GPT2Binder,
LLaMABinder,
PythiaBinder,
QwenBinder,
MistralBinder,
PhiBinder,
BERTBinder,
T5Binder,
GenericBinder
)
[docs]
class BinderRegistry:
    """
    Registry for model binders.

    Maps architecture names (stored lower-cased) to binder classes and
    hands out binder instances. Instances are created lazily by
    ``get_binder`` and cached per lower-cased queried name, so repeated
    lookups of the same name return the same object.

    Resolution order in ``get_binder``:
      1. a cached instance for the name;
      2. an exactly registered name;
      3. a fuzzy match (see ``_find_compatible_binder``);
      4. a ``GenericBinder`` fallback.
    """

    def __init__(self):
        """Initialize the binder registry and register the default binders."""
        # Lower-cased architecture name -> binder class.
        self._binders: Dict[str, Type[ModelBinder]] = {}
        # Lower-cased *queried* name -> cached binder instance. Keys may be
        # names that were resolved by fuzzy matching or generic fallback.
        self._instances: Dict[str, ModelBinder] = {}
        # Register default binders
        self._register_default_binders()

    def _register_default_binders(self) -> None:
        """Register the default set of architecture aliases."""
        default_binders = [
            ("gpt2", GPT2Binder),
            ("gptj", GPT2Binder),  # Similar structure
            ("gpt-j", GPT2Binder),
            ("llama", LLaMABinder),
            ("llama2", LLaMABinder),
            ("llama3", LLaMABinder),
            ("codellama", LLaMABinder),
            ("pythia", PythiaBinder),
            ("gpt-neox", PythiaBinder),
            ("gptneox", PythiaBinder),
            ("qwen", QwenBinder),
            ("qwen2", QwenBinder),
            ("mistral", MistralBinder),
            ("mixtral", MistralBinder),
            ("phi", PhiBinder),
            ("phi2", PhiBinder),
            ("phi3", PhiBinder),
            ("bert", BERTBinder),
            ("roberta", BERTBinder),  # Similar structure
            ("distilbert", BERTBinder),
            ("t5", T5Binder),
            ("t5-base", T5Binder),
            ("mt5", T5Binder),
            ("generic", GenericBinder),
            ("unknown", GenericBinder),
        ]
        for arch_name, binder_class in default_binders:
            self.register(arch_name, binder_class)
        logger.info(f"Registered {len(self._binders)} default binders")

    def register(
        self,
        architecture: str,
        binder_class: Type[ModelBinder]
    ) -> None:
        """
        Register a new binder for an architecture.

        Re-registering an existing name replaces its binder class. Cached
        instances whose resolution this registration could change are
        discarded, so the next ``get_binder`` call sees the new mapping.

        Args:
            architecture: Architecture name (case-insensitive)
            binder_class: Binder class to register
        """
        arch_lower = architecture.lower()
        self._binders[arch_lower] = binder_class
        # A cached entry can be affected if it is this exact name, or if it
        # was resolved by fuzzy matching / generic fallback (i.e. its key is
        # not an exactly registered name). Entries cached via a different
        # exact name are unaffected because exact matches take precedence.
        stale = [
            key for key in self._instances
            if key == arch_lower or key not in self._binders
        ]
        for key in stale:
            del self._instances[key]
        logger.debug(f"Registered binder for '{arch_lower}'")

    def get_binder(self, architecture: str) -> ModelBinder:
        """
        Get a binder instance for the specified architecture.

        Args:
            architecture: Architecture name (case-insensitive)

        Returns:
            Binder instance for the architecture. The instance is cached and
            shared between calls with the same (lower-cased) name. If no
            registered binder matches, a ``GenericBinder`` is returned (and
            cached) after logging a warning.
        """
        arch_lower = architecture.lower()

        # Check if we have a cached instance
        cached = self._instances.get(arch_lower)
        if cached is not None:
            return cached

        # Exact match first, then fuzzy matching
        if arch_lower in self._binders:
            binder_class = self._binders[arch_lower]
        else:
            binder_class = self._find_compatible_binder(arch_lower)

        if binder_class is not None:
            instance = binder_class()
            self._instances[arch_lower] = instance
            logger.info(f"Using {binder_class.__name__} for '{architecture}'")
            return instance

        # Fall back to generic binder
        logger.warning(
            f"No specific binder found for '{architecture}', "
            "using generic binder"
        )
        instance = GenericBinder()
        self._instances[arch_lower] = instance
        return instance

    def _find_compatible_binder(
        self,
        architecture: str
    ) -> Optional[Type[ModelBinder]]:
        """
        Find a compatible binder for an architecture by fuzzy matching.

        Tries substring matching against the registered names. If nothing
        matches, removes ``-<digits>`` / ``_<digits>`` segments (for example
        ``"gpt-2"`` -> ``"gpt"``) and retries until the name stops changing.

        Args:
            architecture: Architecture name to match (callers in this class
                pass it already lower-cased)

        Returns:
            Compatible binder class, or None if nothing matches. An empty
            name, or one that reduces to empty, never matches.
        """
        import re

        candidate = architecture
        # An empty candidate would be a substring of every registered name.
        while candidate:
            binder_class = self._match_by_substring(candidate)
            if binder_class is not None:
                return binder_class
            # Try removing version numbers
            stripped = re.sub(r'[-_]\d+', '', candidate)
            if stripped == candidate:
                break
            candidate = stripped
        return None

    def _match_by_substring(
        self,
        architecture: str
    ) -> Optional[Type[ModelBinder]]:
        """
        Return the most specific registered binder related by substring.

        Preference order:
          1. the longest registered name contained in ``architecture``
             (e.g. ``"codellama"`` over ``"llama"`` for ``"codellama-7b"``);
          2. otherwise, the shortest registered name that contains
             ``architecture`` (e.g. ``"gpt2"`` for ``"gpt"``).
        Ties are broken by registration order.

        Args:
            architecture: Non-empty architecture name to match

        Returns:
            Matching binder class, or None if no registered name is related.
        """
        if not architecture:
            return None

        contained = [name for name in self._binders if name in architecture]
        if contained:
            # max/min return the first extreme element, so ties keep
            # registration order.
            best = max(contained, key=len)
        else:
            containing = [
                name for name in self._binders if architecture in name
            ]
            if not containing:
                return None
            best = min(containing, key=len)

        logger.debug(
            f"Found compatible binder '{best}' "
            f"for '{architecture}'"
        )
        return self._binders[best]

    def list_supported_architectures(self) -> list:
        """
        List all explicitly registered architecture names.

        Returns:
            Sorted list of lower-cased architecture names
        """
        return sorted(self._binders)

    def is_architecture_supported(self, architecture: str) -> bool:
        """
        Check if an architecture resolves to a registered binder.

        A name is supported if it is registered exactly, or if fuzzy
        matching finds a compatible binder. The ``GenericBinder`` fallback
        used by ``get_binder`` does not count.

        Args:
            architecture: Architecture name to check (case-insensitive)

        Returns:
            True if architecture is supported
        """
        arch_lower = architecture.lower()
        # Check exact match
        if arch_lower in self._binders:
            return True
        # Check if we can find a compatible binder
        return self._find_compatible_binder(arch_lower) is not None