"""
Enhanced authentication and authorization system.

Implements Malaysian business requirements and security best practices:

* ``MalaysianBusinessAuth`` - business-registration and IC-number validation.
* ``SecurePasswordManager`` - password strength, history, expiry and lockout.
* ``SecureSessionManager`` - Redis-backed session tracking.
* ``SecureJWTAuthentication`` - DRF authentication class that layers extra
  checks on top of simplejwt's ``JWTAuthentication``.
* Signal handlers wiring the above into Django's login/logout events.
"""

import functools
import json
import logging
import re
import secrets
from datetime import date, datetime, timedelta
from datetime import timezone as dt_timezone
from typing import Any, Dict, List, Optional, Tuple

import redis
from django.conf import settings
from django.contrib.auth import authenticate, get_user_model, login, logout
from django.contrib.auth.hashers import check_password as check_password_hash
from django.contrib.auth.password_validation import validate_password
from django.contrib.auth.signals import user_logged_in, user_logged_out, user_login_failed
from django.core.cache import cache
from django.core.exceptions import ValidationError
from django.dispatch import receiver
from django.utils import timezone
from django.utils.crypto import get_random_string, salted_hmac
from rest_framework.exceptions import AuthenticationFailed
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.tokens import RefreshToken

logger = logging.getLogger(__name__)
User = get_user_model()


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

@functools.lru_cache(maxsize=None)
def _redis_client_for_url(url: str):
    """Return one Redis client per URL, shared by every caller.

    A redis-py client wraps its own connection pool. The auth classes below
    may be instantiated repeatedly (DRF creates authentication classes per
    request), so sharing avoids opening a new pool each time.
    """
    return redis.from_url(url)


def _get_shared_redis_client(purpose: str):
    """Return the shared Redis client, or None (with a warning) if unavailable.

    Args:
        purpose: Human-readable label used only in the warning message.
    """
    try:
        return _redis_client_for_url(settings.REDIS_URL)
    except Exception:
        # Best effort: callers treat a None client as "Redis disabled".
        logger.warning("Redis not available for %s", purpose)
        return None


def _client_ip_from_request(request) -> str:
    """Return the client IP: first X-Forwarded-For hop, else REMOTE_ADDR.

    SECURITY NOTE(review): X-Forwarded-For is client-controlled unless a
    trusted reverse proxy overwrites it. Any IP binding built on this value
    can be spoofed by sending the header directly; confirm the deployment
    strips or overwrites it.
    """
    x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
    if x_forwarded_for:
        return x_forwarded_for.split(',')[0].strip()
    return request.META.get('REMOTE_ADDR', 'unknown')


def _as_lower_text(value: Any) -> str:
    """Return ``value`` as lower-case text, mapping None to ''."""
    return '' if value is None else str(value).lower()


def _remote_addr(request) -> Optional[str]:
    """Return REMOTE_ADDR for logging; tolerates signals sent with request=None."""
    if request is None:
        return None
    return request.META.get('REMOTE_ADDR')


# ---------------------------------------------------------------------------
# Malaysian business rules
# ---------------------------------------------------------------------------

class MalaysianBusinessAuth:
    """Malaysian business-specific authentication helpers."""

    # Leading digits accepted by validate_business_registration.
    # NOTE(review): the labels below are the original authors'. The 12-digit
    # SSM format is believed to start with the registration year instead;
    # confirm before relying on these meanings.
    VALID_REGISTRATION_PREFIXES = (
        '1',  # Sole proprietorship
        '2',  # Partnership
        '3',  # Private limited company
        '4',  # Public limited company
    )

    def __init__(self):
        self.redis_client = self._get_redis_client()
        self.business_registration_pattern = re.compile(r'^\d{12}$')
        self.ic_pattern = re.compile(r'^\d{12}$|^(\d{6}-\d{2}-\d{4})$')

    def _get_redis_client(self):
        """Return the shared Redis client, or None if Redis is unavailable.

        This method was previously missing, so constructing this class (and
        therefore importing this module) raised AttributeError.
        """
        return _get_shared_redis_client("business authentication")

    def validate_business_registration(self, registration_number: str) -> bool:
        """Validate a Malaysian business registration number.

        Spaces and dashes are stripped. The remainder must be exactly 12
        digits whose first digit is in ``VALID_REGISTRATION_PREFIXES``.
        Non-string input (e.g. a NULL database column) is rejected.
        """
        if not isinstance(registration_number, str):
            return False
        clean_number = re.sub(r'[\s-]', '', registration_number)
        if not self.business_registration_pattern.match(clean_number):
            return False
        return clean_number[0] in self.VALID_REGISTRATION_PREFIXES

    def validate_ic_number(self, ic_number: str) -> Tuple[bool, Optional[int]]:
        """Validate a Malaysian IC number and return the holder's age.

        Spaces and dashes are stripped and 12 digits are required. The first
        six digits are read as a YYMMDD birth date. The century is the most
        recent one that does not put the birth date in the future.

        Returns:
            ``(True, age_in_whole_years)`` when valid, otherwise ``(False, None)``.
        """
        if not isinstance(ic_number, str):
            return False, None
        clean_ic = re.sub(r'[\s-]', '', ic_number)
        if not self.ic_pattern.match(clean_ic) or len(clean_ic) != 12:
            return False, None

        try:
            two_digit_year = int(clean_ic[:2])
            birth_month = int(clean_ic[2:4])
            birth_day = int(clean_ic[4:6])

            today = date.today()
            century = 2000 if 2000 + two_digit_year <= today.year else 1900
            # date() validates month/day (including Feb 29 for the chosen century).
            birth_date = date(century + two_digit_year, birth_month, birth_day)
            if birth_date > today:
                # Same two-digit year as today but a later day: previous century.
                birth_date = date(century - 100 + two_digit_year, birth_month, birth_day)
        except ValueError:
            return False, None

        # Exact age in whole years (subtract one if the birthday hasn't occurred yet).
        had_birthday = (today.month, today.day) >= (birth_date.month, birth_date.day)
        age = today.year - birth_date.year - (0 if had_birthday else 1)
        return True, age

    def get_business_tier(self, user: User) -> str:
        """Return 'enterprise', 'professional' or 'basic' from the business profile.

        Missing profiles or missing/None figures count as zero, i.e. 'basic'.
        """
        if not hasattr(user, 'business_profile'):
            return 'basic'
        profile = user.business_profile
        annual_revenue = getattr(profile, 'annual_revenue', 0) or 0
        employee_count = getattr(profile, 'employee_count', 0) or 0

        if annual_revenue > 5000000 or employee_count > 100:
            return 'enterprise'
        if annual_revenue > 1000000 or employee_count > 20:
            return 'professional'
        return 'basic'


# ---------------------------------------------------------------------------
# Passwords and lockout
# ---------------------------------------------------------------------------

class SecurePasswordManager:
    """Password strength, history, expiry and failed-login lockout rules.

    The history/expiry/lockout methods rely on optional user-model fields
    (``password_history``, ``password_change_date``, ``failed_login_attempts``,
    ``locked_until``). They degrade gracefully when those fields are absent.
    """

    # Passwords rejected outright (compared case-insensitively).
    MALAYSIAN_COMMON_PASSWORDS = frozenset({
        'malaysia', 'kuala', 'lumpur', 'putrajaya', 'johor', 'selangor',
        'penang', 'sabah', 'sarawak', 'melaka',
        '123456', 'password', 'qwerty', 'abc123',
    })

    def __init__(self):
        self.min_length = getattr(settings, 'PASSWORD_MIN_LENGTH', 12)
        self.max_age_days = getattr(settings, 'PASSWORD_MAX_AGE_DAYS', 90)
        self.history_count = getattr(settings, 'PASSWORD_HISTORY_COUNT', 5)
        self.lockout_threshold = getattr(settings, 'PASSWORD_LOCKOUT_THRESHOLD', 5)
        self.lockout_duration = getattr(settings, 'PASSWORD_LOCKOUT_DURATION', 15)  # minutes

    def validate_password_strength(self, password: str, user: User) -> List[str]:
        """Return a list of human-readable problems with ``password`` (empty if OK).

        Combines Django's configured validators with a minimum length, a
        common-password list and a personal-information check.
        """
        errors: List[str] = []

        try:
            validate_password(password, user)
        except ValidationError as e:
            errors.extend(e.messages)

        if len(password) < self.min_length:
            errors.append(f'Password must be at least {self.min_length} characters long')

        lowered = password.lower()
        if lowered in self.MALAYSIAN_COMMON_PASSWORDS:
            errors.append('Password is too common')

        # Reported once, even if several identifiers match.
        if self._contains_personal_info(lowered, user):
            errors.append('Password contains personal information')

        return errors

    @staticmethod
    def _contains_personal_info(lowered_password: str, user) -> bool:
        """Return True if a non-empty profile/user identifier occurs in the password."""
        if not hasattr(user, 'profile'):
            return False
        profile = user.profile
        email = getattr(user, 'email', '') or ''
        candidates = (
            _as_lower_text(getattr(profile, 'first_name', '')),
            _as_lower_text(getattr(profile, 'last_name', '')),
            _as_lower_text(getattr(profile, 'ic_number', '')),
            _as_lower_text(getattr(profile, 'business_name', '')),
            _as_lower_text(getattr(user, 'username', '')),
            email.split('@')[0].lower(),
        )
        return any(info and info in lowered_password for info in candidates)

    def _recent_history(self, history: List[str]) -> List[str]:
        """Return the last ``history_count`` hashes (none if the count is not positive).

        Guards against ``history[-0:]``, which would return the whole list.
        """
        if self.history_count <= 0:
            return []
        return history[-self.history_count:]

    def check_password_history(self, user: User, new_password: str) -> bool:
        """Return True if ``new_password`` matches none of the recent stored hashes."""
        if not hasattr(user, 'password_history'):
            return True
        history = json.loads(user.password_history) if user.password_history else []
        # The module-level hasher verifies a raw password against an arbitrary
        # encoded hash. ``user.check_password`` only accepts the raw password.
        return not any(
            check_password_hash(new_password, old_hash)
            for old_hash in self._recent_history(history)
        )

    def record_password_change(self, user: User, new_password: str):
        """Append the user's current stored hash to the history and save.

        The hash recorded is ``user.password`` as it is at call time, so call
        this after ``set_password()`` if the new password itself should be
        blocked from reuse. ``new_password`` is accepted for API compatibility
        but is not used.
        """
        history = json.loads(getattr(user, 'password_history', None) or '[]')
        history.append(user.password)
        user.password_history = json.dumps(self._recent_history(history))
        user.password_change_date = timezone.now()
        user.save()

    def check_password_expiry(self, user: User) -> bool:
        """Return True if the password is expired, or its change date is unknown."""
        changed_at = getattr(user, 'password_change_date', None)
        if changed_at is None:
            return True
        return timezone.now() > changed_at + timedelta(days=self.max_age_days)

    def is_account_locked(self, user: User) -> bool:
        """Return True while the account is locked due to failed login attempts.

        The lock ends once ``locked_until`` (set by ``record_failed_login``)
        has passed. Previously that timestamp was ignored, so a lock never
        expired on its own.
        """
        attempts = getattr(user, 'failed_login_attempts', None)
        if attempts is None or attempts < self.lockout_threshold:
            return False
        locked_until = getattr(user, 'locked_until', None)
        if locked_until is not None and timezone.now() >= locked_until:
            return False
        return True

    def record_failed_login(self, user: User):
        """Increment the failed-attempt counter and lock the account at the threshold.

        No-op for user models without a ``failed_login_attempts`` field.
        """
        if not hasattr(user, 'failed_login_attempts'):
            return
        user.failed_login_attempts = (user.failed_login_attempts or 0) + 1
        if user.failed_login_attempts >= self.lockout_threshold:
            user.locked_until = timezone.now() + timedelta(minutes=self.lockout_duration)
        user.save()

    def reset_failed_logins(self, user: User):
        """Clear the failed-attempt counter and any lock.

        No-op for user models without a ``failed_login_attempts`` field.
        """
        if not hasattr(user, 'failed_login_attempts'):
            return
        user.failed_login_attempts = 0
        user.locked_until = None
        user.save()


# ---------------------------------------------------------------------------
# Sessions
# ---------------------------------------------------------------------------

class SecureSessionManager:
    """Redis-backed session records keyed ``session:<id>`` with a sliding TTL."""

    def __init__(self):
        self.redis_client = self._get_redis_client()
        self.session_timeout = getattr(settings, 'SESSION_TIMEOUT', 3600)  # 1 hour
        self.concurrent_sessions = getattr(settings, 'MAX_CONCURRENT_SESSIONS', 3)

    def _get_redis_client(self):
        """Return the shared Redis client for session management, or None."""
        return _get_shared_redis_client("session management")

    def create_secure_session(self, user: User, request) -> str:
        """Create a session record with device info and return its id.

        The id is returned even when Redis is unavailable, in which case
        nothing is stored and ``validate_session`` will reject it.
        """
        session_id = secrets.token_urlsafe(32)
        now = timezone.now().isoformat()
        session_data = {
            'user_id': user.id,
            'created_at': now,
            'last_activity': now,
            'device_info': self._get_device_fingerprint(request),
            'ip_address': self._get_client_ip(request),
            'user_agent': request.META.get('HTTP_USER_AGENT', ''),
            'is_active': True,
        }
        if self.redis_client:
            self.redis_client.setex(
                f"session:{session_id}",
                self.session_timeout,
                json.dumps(session_data),
            )
        return session_id

    def validate_session(self, session_id: str, request) -> Optional[User]:
        """Return the active user for ``session_id`` and refresh its TTL, else None."""
        if not self.redis_client:
            return None
        key = f"session:{session_id}"
        raw = self.redis_client.get(key)
        if not raw:
            return None

        try:
            session_info = json.loads(raw)
            if not session_info.get('is_active', True):
                return None
            # Sliding expiry: every successful validation extends the TTL.
            session_info['last_activity'] = timezone.now().isoformat()
            self.redis_client.setex(key, self.session_timeout, json.dumps(session_info))
            user_id = session_info['user_id']
        except (json.JSONDecodeError, KeyError):
            return None

        try:
            return User.objects.get(id=user_id, is_active=True)
        except User.DoesNotExist:
            return None

    def revoke_session(self, session_id: str):
        """Delete one session record (no-op without Redis)."""
        if self.redis_client:
            self.redis_client.delete(f"session:{session_id}")

    def _iter_user_sessions(self, user: User):
        """Yield ``(redis_key, session_dict)`` for each stored session of ``user``.

        Uses SCAN rather than KEYS so Redis is not blocked on large keyspaces.
        Unparseable or non-dict records are skipped.
        """
        for key in self.redis_client.scan_iter(match="session:*"):
            raw = self.redis_client.get(key)
            if not raw:
                continue
            try:
                info = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if isinstance(info, dict) and info.get('user_id') == user.id:
                yield key, info

    @staticmethod
    def _session_id_from_key(key) -> str:
        """Strip the ``session:`` prefix (keys may be bytes or str)."""
        if isinstance(key, bytes):
            key = key.decode('utf-8')
        return key.split(':', 1)[1]

    def revoke_all_sessions(self, user: User):
        """Delete every stored session belonging to ``user``."""
        if not self.redis_client:
            return
        # Collect first so we don't mutate the keyspace mid-scan.
        keys = [key for key, _ in self._iter_user_sessions(user)]
        if keys:
            self.redis_client.delete(*keys)

    def get_active_sessions(self, user: User) -> List[Dict[str, Any]]:
        """Return summaries of ``user``'s active sessions (empty without Redis)."""
        if not self.redis_client:
            return []
        sessions = []
        for key, info in self._iter_user_sessions(user):
            if not info.get('is_active', True):
                continue
            try:
                sessions.append({
                    'session_id': self._session_id_from_key(key),
                    'created_at': info['created_at'],
                    'last_activity': info['last_activity'],
                    'device_info': info['device_info'],
                    'ip_address': info['ip_address'],
                })
            except KeyError:
                continue
        return sessions

    def _get_device_fingerprint(self, request) -> Dict[str, str]:
        """Return descriptive device info (raw headers plus coarse browser/OS labels)."""
        user_agent = request.META.get('HTTP_USER_AGENT', '')
        accept_language = request.META.get('HTTP_ACCEPT_LANGUAGE', '')
        return {
            'user_agent': user_agent,
            'accept_language': accept_language,
            'browser': self._get_browser_info(user_agent),
            'os': self._get_os_info(user_agent),
        }

    def _get_browser_info(self, user_agent: str) -> str:
        """Return a coarse browser name from a User-Agent string.

        Order matters: Edge UAs also contain "Chrome" ("Edg/" for Chromium
        Edge, "Edge/" for legacy), and Chrome UAs also contain "Safari".
        """
        if 'Edg' in user_agent:
            return 'Edge'
        if 'Chrome' in user_agent:
            return 'Chrome'
        if 'Firefox' in user_agent:
            return 'Firefox'
        if 'Safari' in user_agent:
            return 'Safari'
        return 'Unknown'

    def _get_os_info(self, user_agent: str) -> str:
        """Return a coarse OS name from a User-Agent string.

        Order matters: Android UAs contain "Linux", and iOS UAs contain
        "like Mac OS X", so those are checked first.
        """
        if 'Windows' in user_agent:
            return 'Windows'
        if 'Android' in user_agent:
            return 'Android'
        if any(marker in user_agent for marker in ('iPhone', 'iPad', 'iPod', 'iOS')):
            return 'iOS'
        if 'Mac' in user_agent:
            return 'macOS'
        if 'Linux' in user_agent:
            return 'Linux'
        return 'Unknown'

    def _get_client_ip(self, request) -> str:
        """Return the client IP (see ``_client_ip_from_request`` for the spoofing caveat)."""
        return _client_ip_from_request(request)


# ---------------------------------------------------------------------------
# JWT authentication
# ---------------------------------------------------------------------------

class SecureJWTAuthentication(JWTAuthentication):
    """simplejwt authentication plus account, claim, session and business checks."""

    def __init__(self):
        super().__init__()
        self.password_manager = SecurePasswordManager()
        self.session_manager = SecureSessionManager()
        self.malaysian_auth = MalaysianBusinessAuth()

    def authenticate(self, request):
        """Authenticate a ``Bearer`` JWT and apply the extra security checks.

        Returns:
            ``None`` if no Bearer header is present (lets other authenticators
            run); otherwise ``(user, validated_token)``. DRF unpacks this tuple,
            and the base ``JWTAuthentication`` returns the same shape.

        Raises:
            AuthenticationFailed: 'Invalid token' for token, lookup or
                unexpected errors; 'Security validation failed' when the user
                fails ``_perform_security_checks``.
        """
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return None

        try:
            raw_token = auth_header.split(' ')[1]
            validated_token = self.get_validated_token(raw_token)
            user = User.objects.get(id=validated_token['user_id'])
            checks_passed = self._perform_security_checks(user, request, validated_token)
        except Exception as e:
            logger.error("Authentication error: %s", e)
            raise AuthenticationFailed('Invalid token') from e

        # Raised outside the try so the broad except above cannot rewrite it.
        if not checks_passed:
            logger.warning("Security validation failed for user id %s", user.pk)
            raise AuthenticationFailed('Security validation failed')

        return user, validated_token

    def _perform_security_checks(self, user: User, request, token) -> bool:
        """Return True only if every account, claim, session and business check passes."""
        if not user.is_active:
            return False
        if self.password_manager.is_account_locked(user):
            return False
        if self.password_manager.check_password_expiry(user):
            return False
        if not self._validate_token_claims(token, user, request):
            return False
        if not self._check_session_concurrency(user, token):
            return False
        if hasattr(user, 'business_profile'):
            registration_number = getattr(user.business_profile, 'registration_number', None)
            if not self.malaysian_auth.validate_business_registration(registration_number):
                return False
        return True

    def _validate_token_claims(self, token, user, request) -> bool:
        """Check expiry, issuer, audience and optional IP/device binding claims."""
        # Compare in UTC on both sides. django.utils.timezone.utc was removed in
        # Django 5.0, and timezone.now() is naive when USE_TZ=False.
        now_utc = datetime.now(tz=dt_timezone.utc)
        if now_utc > datetime.fromtimestamp(token['exp'], tz=dt_timezone.utc):
            return False

        if token.get('iss') != getattr(settings, 'JWT_ISSUER', 'malaysian-sme-platform'):
            return False
        if token.get('aud') != getattr(settings, 'JWT_AUDIENCE', 'malaysian-sme-users'):
            return False

        # Bindings only apply when enabled AND the token carries the claim.
        if getattr(settings, 'JWT_BIND_TO_IP', False):
            token_ip = token.get('ip_address')
            if token_ip and token_ip != self._get_client_ip(request):
                return False

        if getattr(settings, 'JWT_BIND_TO_DEVICE', False):
            token_device = token.get('device_fingerprint')
            if token_device and token_device != self._get_device_fingerprint(request):
                return False

        return True

    def _check_session_concurrency(self, user, token) -> bool:
        """Return False if the user's *other* active sessions already hit the limit.

        Only enforced when ``ENFORCE_SESSION_CONCURRENCY`` is set and the
        token carries a ``session_id`` claim.
        """
        if not getattr(settings, 'ENFORCE_SESSION_CONCURRENCY', False):
            return True
        current_session_id = token.get('session_id')
        if not current_session_id:
            return True
        other_sessions = [
            s for s in self.session_manager.get_active_sessions(user)
            if s['session_id'] != current_session_id
        ]
        return len(other_sessions) < self.session_manager.concurrent_sessions

    def _get_client_ip(self, request) -> str:
        """Return the client IP (see ``_client_ip_from_request`` for the spoofing caveat)."""
        return _client_ip_from_request(request)

    def _get_device_fingerprint(self, request) -> str:
        """Return an HMAC hex digest of User-Agent + Accept-Language (keyed on SECRET_KEY)."""
        user_agent = request.META.get('HTTP_USER_AGENT', '')
        accept_language = request.META.get('HTTP_ACCEPT_LANGUAGE', '')
        return salted_hmac('device_fingerprint', f"{user_agent}{accept_language}").hexdigest()


# ---------------------------------------------------------------------------
# Signal handlers
# ---------------------------------------------------------------------------

@receiver(user_logged_in)
def user_logged_in_handler(sender, request, user, **kwargs):
    """Reset the failed-login counter and open a tracked session on login.

    The new session id is stored on ``request.session_id`` for token issuance.
    No session is created if the signal was sent without a request.
    """
    logger.info("User %s logged in successfully from %s", user.get_username(), _remote_addr(request))

    SecurePasswordManager().reset_failed_logins(user)

    if request is not None:
        request.session_id = SecureSessionManager().create_secure_session(user, request)


@receiver(user_logged_out)
def user_logged_out_handler(sender, request, user, **kwargs):
    """Revoke the tracked session (if any) on logout.

    Django sends ``user=None`` when the logging-out user was not authenticated.
    """
    username = user.get_username() if user is not None else 'anonymous'
    logger.info("User %s logged out", username)

    if hasattr(request, 'session_id'):
        SecureSessionManager().revoke_session(request.session_id)


@receiver(user_login_failed)
def user_login_failed_handler(sender, credentials, request=None, **kwargs):
    """Count a failed login against the named user, if that user exists.

    ``request`` may be None: Django's ``authenticate()`` forwards whatever
    request it was given. No user is looked up when the credentials carry no
    identifier (previously a user literally named "unknown" was penalised).
    """
    identifier = credentials.get(User.USERNAME_FIELD, credentials.get('username'))
    logger.warning(
        "Failed login attempt for username: %s from %s",
        identifier or 'unknown',
        _remote_addr(request),
    )
    if not identifier:
        return

    try:
        user = User.objects.get(**{User.USERNAME_FIELD: identifier})
    except User.DoesNotExist:
        return
    SecurePasswordManager().record_failed_login(user)


# Global instances
malaysian_business_auth = MalaysianBusinessAuth()
password_manager = SecurePasswordManager()
session_manager = SecureSessionManager()
secure_jwt_auth = SecureJWTAuthentication()