""" Enhanced API security with JWT, OAuth, and Malaysian-specific authentication. """ import json import secrets import time from datetime import datetime, timedelta from typing import Dict, List, Optional, Any, Tuple from django.conf import settings from django.contrib.auth import get_user_model from django.core.cache import cache from django.http import HttpRequest, HttpResponse, JsonResponse from django.utils import timezone from django.views.decorators.csrf import csrf_exempt from django.views.decorators.http import require_http_methods from rest_framework import status from rest_framework.decorators import api_view, permission_classes from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework_simplejwt.tokens import RefreshToken, AccessToken from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView from rest_framework_simplejwt.serializers import TokenObtainPairSerializer from jose import jwt, JWTError from jose.exceptions import ExpiredSignatureError, JWTClaimsError import logging import hashlib import hmac from urllib.parse import urlparse import requests from requests.auth import HTTPBasicAuth from .auth import SecurePasswordManager, SecureSessionManager from .middleware import SecurityLoggingMiddleware logger = logging.getLogger(__name__) User = get_user_model() class MalaysianBusinessAuthenticator: """ Malaysian business authentication and authorization system. """ def __init__(self): self.password_manager = SecurePasswordManager() self.session_manager = SecureSessionManager() self.logger = logging.getLogger('security.auth') def authenticate_business(self, request: HttpRequest) -> Optional[Dict]: """ Authenticate Malaysian business with enhanced security. """ try: # Get credentials business_registration = request.data.get('business_registration') password = request.data.get('password') ssm_code = request.data.get('ssm_code') # Validate required fields if not all([business_registration, password, ssm_code]): return None # Verify SSM code format if not self._validate_ssm_code(ssm_code): self.logger.warning(f"Invalid SSM code format: {ssm_code}") return None # Get user by business registration try: user = User.objects.get( business_registration__registration_number=business_registration, is_active=True ) except User.DoesNotExist: self.logger.warning(f"Business not found: {business_registration}") return None # Verify password if not self.password_manager.verify_password(password, user.password): self.logger.warning(f"Invalid password for business: {business_registration}") return None # Verify SSM code if not self._verify_ssm_code(user, ssm_code): self.logger.warning(f"Invalid SSM code for business: {business_registration}") return None # Check business status if not self._check_business_status(user): self.logger.warning(f"Business not active: {business_registration}") return None # Update last login user.last_login = timezone.now() user.save() return { 'user': user, 'business_registration': business_registration, 'ssm_code': ssm_code, } except Exception as e: self.logger.error(f"Authentication error: {e}") return None def _validate_ssm_code(self, ssm_code: str) -> bool: """ Validate SSM code format. """ # SSM code should be 6 digits if len(ssm_code) != 6: return False # Should be numeric if not ssm_code.isdigit(): return False return True def _verify_ssm_code(self, user: User, ssm_code: str) -> bool: """ Verify SSM code against stored hash. """ try: # Get stored SSM code hash stored_hash = getattr(user, 'ssm_code_hash', None) if not stored_hash: return False # Verify hash return self.password_manager.verify_password(ssm_code, stored_hash) except Exception as e: self.logger.error(f"SSM code verification error: {e}") return False def _check_business_status(self, user: User) -> bool: """ Check if business is active and compliant. """ try: # Check if user is active if not user.is_active: return False # Check business registration status business_registration = getattr(user, 'business_registration', None) if not business_registration: return False # Check if registration is active if not business_registration.is_active: return False # Check compliance status if not business_registration.is_compliant: return False return True except Exception as e: self.logger.error(f"Business status check error: {e}") return False class SecureJWTAuthentication: """ Enhanced JWT authentication with Malaysian security considerations. """ def __init__(self): self.secret_key = settings.SECRET_KEY self.algorithm = 'HS256' self.access_token_lifetime = getattr(settings, 'SIMPLE_JWT', {}).get('ACCESS_TOKEN_LIFETIME', timedelta(minutes=15)) self.refresh_token_lifetime = getattr(settings, 'SIMPLE_JWT', {}).get('REFRESH_TOKEN_LIFETIME', timedelta(days=7)) self.logger = logging.getLogger('security.jwt') def generate_token_pair(self, user: User) -> Dict[str, str]: """ Generate secure JWT token pair. """ try: # Generate access token access_token = self._generate_access_token(user) # Generate refresh token refresh_token = self._generate_refresh_token(user) return { 'access_token': access_token, 'refresh_token': refresh_token, 'token_type': 'Bearer', 'expires_in': int(self.access_token_lifetime.total_seconds()), } except Exception as e: self.logger.error(f"Token generation error: {e}") raise def _generate_access_token(self, user: User) -> str: """ Generate access token with enhanced claims. """ now = datetime.utcnow() # Basic claims claims = { 'user_id': user.id, 'username': user.username, 'email': user.email, 'business_registration': getattr(user, 'business_registration', {}).get('registration_number', ''), 'exp': now + self.access_token_lifetime, 'iat': now, 'jti': secrets.token_urlsafe(16), 'type': 'access', } # Add Malaysian-specific claims claims.update({ 'malaysian_business': True, 'business_type': getattr(user, 'business_registration', {}).get('business_type', ''), 'state': getattr(user, 'business_registration', {}).get('state', ''), 'compliance_status': getattr(user, 'business_registration', {}).get('compliance_status', ''), }) # Add security claims claims.update({ 'device_fingerprint': self._get_device_fingerprint(), 'ip_address': self._get_client_ip(), 'session_id': self._generate_session_id(), }) return jwt.encode(claims, self.secret_key, algorithm=self.algorithm) def _generate_refresh_token(self, user: User) -> str: """ Generate refresh token with enhanced claims. """ now = datetime.utcnow() claims = { 'user_id': user.id, 'exp': now + self.refresh_token_lifetime, 'iat': now, 'jti': secrets.token_urlsafe(16), 'type': 'refresh', } return jwt.encode(claims, self.secret_key, algorithm=self.algorithm) def verify_token(self, token: str, token_type: str = 'access') -> Optional[Dict]: """ Verify JWT token with enhanced validation. """ try: # Decode token claims = jwt.decode(token, self.secret_key, algorithms=[self.algorithm]) # Validate token type if claims.get('type') != token_type: self.logger.warning(f"Invalid token type: {token_type}") return None # Validate claims if not self._validate_claims(claims): return None # Check if token is revoked if self._is_token_revoked(claims): self.logger.warning("Token revoked") return None return claims except ExpiredSignatureError: self.logger.warning("Token expired") return None except JWTClaimsError as e: self.logger.warning(f"Invalid claims: {e}") return None except JWTError as e: self.logger.warning(f"JWT error: {e}") return None def _validate_claims(self, claims: Dict) -> bool: """ Validate token claims. """ # Check required claims required_claims = ['user_id', 'jti', 'type', 'iat', 'exp'] for claim in required_claims: if claim not in claims: self.logger.warning(f"Missing required claim: {claim}") return False # Validate user exists try: user = User.objects.get(id=claims['user_id']) if not user.is_active: self.logger.warning("User not active") return False except User.DoesNotExist: self.logger.warning("User not found") return False # Validate device fingerprint if present if 'device_fingerprint' in claims: if claims['device_fingerprint'] != self._get_device_fingerprint(): self.logger.warning("Device fingerprint mismatch") return False return True def _is_token_revoked(self, claims: Dict) -> bool: """ Check if token is revoked. """ try: jti = claims.get('jti') if not jti: return True # Check cache for revoked tokens cache_key = f"revoked_token_{jti}" return cache.get(cache_key) is not None except Exception as e: self.logger.error(f"Token revocation check error: {e}") return True def revoke_token(self, claims: Dict): """ Revoke token by adding to revoked list. """ try: jti = claims.get('jti') if not jti: return # Add to revoked tokens cache cache_key = f"revoked_token_{jti}" exp = claims.get('exp') if exp: # Cache until token expires ttl = max(0, exp - int(time.time())) cache.set(cache_key, True, timeout=ttl) except Exception as e: self.logger.error(f"Token revocation error: {e}") def _get_device_fingerprint(self) -> str: """ Generate device fingerprint. """ # Simple device fingerprint based on user agent and IP import hashlib user_agent = "unknown" # Would get from request ip_address = self._get_client_ip() fingerprint = f"{user_agent}:{ip_address}" return hashlib.sha256(fingerprint.encode()).hexdigest() def _get_client_ip(self) -> str: """ Get client IP address. """ # This would get from request headers return "127.0.0.1" # Placeholder def _generate_session_id(self) -> str: """ Generate session ID. """ return secrets.token_urlsafe(32) class OAuth2Provider: """ OAuth2 provider integration for Malaysian services. """ def __init__(self): self.providers = getattr(settings, 'OAUTH2_PROVIDERS', {}) self.logger = logging.getLogger('security.oauth2') def authenticate_with_oauth2(self, provider: str, code: str, redirect_uri: str) -> Optional[Dict]: """ Authenticate using OAuth2 provider. """ try: if provider not in self.providers: self.logger.warning(f"Unknown OAuth2 provider: {provider}") return None provider_config = self.providers[provider] # Exchange code for access token token_response = self._exchange_code_for_token( provider_config, code, redirect_uri ) if not token_response: return None # Get user info user_info = self._get_user_info(provider_config, token_response['access_token']) if not user_info: return None # Find or create user user = self._find_or_create_user(provider, user_info) if not user: return None return { 'user': user, 'provider': provider, 'access_token': token_response['access_token'], 'refresh_token': token_response.get('refresh_token'), 'expires_in': token_response.get('expires_in'), } except Exception as e: self.logger.error(f"OAuth2 authentication error: {e}") return None def _exchange_code_for_token(self, provider_config: Dict, code: str, redirect_uri: str) -> Optional[Dict]: """ Exchange authorization code for access token. """ try: token_url = provider_config['token_url'] client_id = provider_config['client_id'] client_secret = provider_config['client_secret'] data = { 'grant_type': 'authorization_code', 'code': code, 'redirect_uri': redirect_uri, 'client_id': client_id, 'client_secret': client_secret, } response = requests.post(token_url, data=data, timeout=10) if response.status_code != 200: self.logger.warning(f"Token exchange failed: {response.status_code}") return None return response.json() except Exception as e: self.logger.error(f"Token exchange error: {e}") return None def _get_user_info(self, provider_config: Dict, access_token: str) -> Optional[Dict]: """ Get user info from OAuth2 provider. """ try: user_info_url = provider_config['user_info_url'] headers = {'Authorization': f'Bearer {access_token}'} response = requests.get(user_info_url, headers=headers, timeout=10) if response.status_code != 200: self.logger.warning(f"User info request failed: {response.status_code}") return None return response.json() except Exception as e: self.logger.error(f"User info request error: {e}") return None def _find_or_create_user(self, provider: str, user_info: Dict) -> Optional[User]: """ Find or create user from OAuth2 provider info. """ try: # Try to find user by email email = user_info.get('email') if not email: self.logger.warning("No email in user info") return None try: user = User.objects.get(email=email) # Update OAuth2 info if not hasattr(user, 'oauth2_providers'): user.oauth2_providers = {} user.oauth2_providers[provider] = user_info user.save() return user except User.DoesNotExist: # Create new user user = User.objects.create( username=user_info.get('email', ''), email=email, first_name=user_info.get('given_name', ''), last_name=user_info.get('family_name', ''), is_active=True, oauth2_providers={provider: user_info} ) return user except Exception as e: self.logger.error(f"User creation error: {e}") return None class APIRateLimiter: """ Enhanced API rate limiting with Malaysian considerations. """ def __init__(self): self.logger = logging.getLogger('security.ratelimit') # Rate limits self.rate_limits = { 'default': {'requests': 100, 'window': 60}, # 100 requests per minute 'login': {'requests': 5, 'window': 60}, # 5 login attempts per minute 'api': {'requests': 1000, 'window': 60}, # 1000 API requests per minute 'upload': {'requests': 10, 'window': 60}, # 10 uploads per minute 'export': {'requests': 20, 'window': 60}, # 20 exports per minute } # Malaysian business limits self.malaysian_business_limits = { 'sst_calculation': {'requests': 50, 'window': 60}, 'ic_validation': {'requests': 30, 'window': 60}, 'postcode_lookup': {'requests': 100, 'window': 60}, } def check_rate_limit(self, key: str, endpoint: str = 'default') -> bool: """ Check if request is within rate limit. """ try: # Get limit for endpoint if endpoint in self.malaysian_business_limits: limit = self.malaysian_business_limits[endpoint] elif endpoint in self.rate_limits: limit = self.rate_limits[endpoint] else: limit = self.rate_limits['default'] # Check cache cache_key = f"rate_limit_{key}_{endpoint}" current = cache.get(cache_key, 0) if current >= limit['requests']: self.logger.warning(f"Rate limit exceeded for {key} on {endpoint}") return False # Increment counter cache.set(cache_key, current + 1, timeout=limit['window']) return True except Exception as e: self.logger.error(f"Rate limit check error: {e}") return True # Allow on error def get_rate_limit_info(self, key: str, endpoint: str = 'default') -> Dict: """ Get current rate limit information. """ try: # Get limit for endpoint if endpoint in self.malaysian_business_limits: limit = self.malaysian_business_limits[endpoint] elif endpoint in self.rate_limits: limit = self.rate_limits[endpoint] else: limit = self.rate_limits['default'] # Get current count cache_key = f"rate_limit_{key}_{endpoint}" current = cache.get(cache_key, 0) ttl = cache.ttl(cache_key) if ttl < 0: ttl = limit['window'] return { 'limit': limit['requests'], 'remaining': max(0, limit['requests'] - current), 'reset': ttl, 'window': limit['window'], } except Exception as e: self.logger.error(f"Rate limit info error: {e}") return {'limit': 0, 'remaining': 0, 'reset': 0, 'window': 0} # Authentication Views class MalaysianBusinessTokenView(TokenObtainPairView): """ Custom token view for Malaysian business authentication. """ serializer_class = TokenObtainPairSerializer def post(self, request, *args, **kwargs): try: # Initialize authenticator authenticator = MalaysianBusinessAuthenticator() # Authenticate business auth_result = authenticator.authenticate_business(request) if not auth_result: return Response( {'error': 'Invalid credentials'}, status=status.HTTP_401_UNAUTHORIZED ) # Initialize JWT authentication jwt_auth = SecureJWTAuthentication() # Generate tokens tokens = jwt_auth.generate_token_pair(auth_result['user']) # Log successful authentication logger.info(f"Successful authentication for business: {auth_result['business_registration']}") return Response(tokens, status=status.HTTP_200_OK) except Exception as e: logger.error(f"Authentication error: {e}") return Response( {'error': 'Authentication failed'}, status=status.HTTP_500_INTERNAL_SERVER_ERROR ) @api_view(['POST']) @permission_classes([IsAuthenticated]) def refresh_token_view(request): """ Refresh JWT token. """ try: refresh_token = request.data.get('refresh_token') if not refresh_token: return Response( {'error': 'Refresh token required'}, status=status.HTTP_400_BAD_REQUEST ) # Initialize JWT authentication jwt_auth = SecureJWTAuthentication() # Verify refresh token claims = jwt_auth.verify_token(refresh_token, 'refresh') if not claims: return Response( {'error': 'Invalid refresh token'}, status=status.HTTP_401_UNAUTHORIZED ) # Get user try: user = User.objects.get(id=claims['user_id']) except User.DoesNotExist: return Response( {'error': 'User not found'}, status=status.HTTP_404_NOT_FOUND ) # Generate new token pair new_tokens = jwt_auth.generate_token_pair(user) # Revoke old refresh token jwt_auth.revoke_token(claims) return Response(new_tokens, status=status.HTTP_200_OK) except Exception as e: logger.error(f"Token refresh error: {e}") return Response( {'error': 'Token refresh failed'}, status=status.HTTP_500_INTERNAL_SERVER_ERROR ) @api_view(['POST']) @permission_classes([IsAuthenticated]) def revoke_token_view(request): """ Revoke JWT token. """ try: token = request.data.get('token') if not token: return Response( {'error': 'Token required'}, status=status.HTTP_400_BAD_REQUEST ) # Initialize JWT authentication jwt_auth = SecureJWTAuthentication() # Verify token claims = jwt_auth.verify_token(token) if not claims: return Response( {'error': 'Invalid token'}, status=status.HTTP_401_UNAUTHORIZED ) # Revoke token jwt_auth.revoke_token(claims) return Response({'message': 'Token revoked'}, status=status.HTTP_200_OK) except Exception as e: logger.error(f"Token revocation error: {e}") return Response( {'error': 'Token revocation failed'}, status=status.HTTP_500_INTERNAL_SERVER_ERROR ) @api_view(['POST']) def oauth2_login_view(request): """ OAuth2 login view. """ try: provider = request.data.get('provider') code = request.data.get('code') redirect_uri = request.data.get('redirect_uri') if not all([provider, code, redirect_uri]): return Response( {'error': 'Missing required fields'}, status=status.HTTP_400_BAD_REQUEST ) # Initialize OAuth2 provider oauth2_provider = OAuth2Provider() # Authenticate with OAuth2 auth_result = oauth2_provider.authenticate_with_oauth2(provider, code, redirect_uri) if not auth_result: return Response( {'error': 'OAuth2 authentication failed'}, status=status.HTTP_401_UNAUTHORIZED ) # Initialize JWT authentication jwt_auth = SecureJWTAuthentication() # Generate tokens tokens = jwt_auth.generate_token_pair(auth_result['user']) # Log successful authentication logger.info(f"Successful OAuth2 authentication for user: {auth_result['user'].username}") return Response(tokens, status=status.HTTP_200_OK) except Exception as e: logger.error(f"OAuth2 login error: {e}") return Response( {'error': 'OAuth2 authentication failed'}, status=status.HTTP_500_INTERNAL_SERVER_ERROR ) # Middleware for API security class APISecurityMiddleware: """ Middleware for API security with rate limiting and authentication. """ def __init__(self, get_response): self.get_response = get_response self.rate_limiter = APIRateLimiter() self.jwt_auth = SecureJWTAuthentication() self.logger = logging.getLogger('security.api') def __call__(self, request): # Skip for certain endpoints if self._should_skip(request): return self.get_response(request) # Check rate limit if not self._check_rate_limit(request): return JsonResponse( {'error': 'Rate limit exceeded'}, status=status.HTTP_429_TOO_MANY_REQUESTS ) # Check authentication if self._requires_auth(request): auth_result = self._check_authentication(request) if not auth_result: return JsonResponse( {'error': 'Authentication required'}, status=status.HTTP_401_UNAUTHORIZED ) # Add user to request request.user = auth_result['user'] request.auth_claims = auth_result['claims'] return self.get_response(request) def _should_skip(self, request) -> bool: """ Determine if security checks should be skipped. """ skip_paths = [ '/health/', '/metrics/', '/static/', '/media/', '/api/v1/docs/', '/api/v1/oauth2/', ] return any(request.path.startswith(path) for path in skip_paths) def _check_rate_limit(self, request) -> bool: """ Check rate limit for the request. """ try: # Get rate limit key if hasattr(request, 'user') and request.user.is_authenticated: key = f"user_{request.user.id}" else: key = f"ip_{self._get_client_ip(request)}" # Determine endpoint type endpoint = 'default' if 'login' in request.path: endpoint = 'login' elif 'api' in request.path: endpoint = 'api' elif 'upload' in request.path: endpoint = 'upload' elif 'export' in request.path: endpoint = 'export' elif 'sst' in request.path: endpoint = 'sst_calculation' elif 'ic' in request.path: endpoint = 'ic_validation' elif 'postcode' in request.path: endpoint = 'postcode_lookup' return self.rate_limiter.check_rate_limit(key, endpoint) except Exception as e: self.logger.error(f"Rate limit check error: {e}") return True # Allow on error def _requires_auth(self, request) -> bool: """ Check if request requires authentication. """ auth_paths = [ '/api/v1/', '/business/', '/admin/', ] return any(request.path.startswith(path) for path in auth_paths) def _check_authentication(self, request) -> Optional[Dict]: """ Check request authentication. """ try: # Get Authorization header auth_header = request.META.get('HTTP_AUTHORIZATION') if not auth_header: return None # Parse Bearer token if not auth_header.startswith('Bearer '): return None token = auth_header[7:] # Verify token claims = self.jwt_auth.verify_token(token) if not claims: return None # Get user try: user = User.objects.get(id=claims['user_id']) except User.DoesNotExist: return None return { 'user': user, 'claims': claims, } except Exception as e: self.logger.error(f"Authentication check error: {e}") return None def _get_client_ip(self, request) -> str: """ Get client IP address from request. """ x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR') if x_forwarded_for: ip = x_forwarded_for.split(',')[0] else: ip = request.META.get('REMOTE_ADDR') return ip