project initialization
Some checks failed
System Monitoring / Health Checks (push) Has been cancelled
System Monitoring / Performance Monitoring (push) Has been cancelled
System Monitoring / Database Monitoring (push) Has been cancelled
System Monitoring / Cache Monitoring (push) Has been cancelled
System Monitoring / Log Monitoring (push) Has been cancelled
System Monitoring / Resource Monitoring (push) Has been cancelled
System Monitoring / Uptime Monitoring (push) Has been cancelled
System Monitoring / Backup Monitoring (push) Has been cancelled
System Monitoring / Security Monitoring (push) Has been cancelled
System Monitoring / Monitoring Dashboard (push) Has been cancelled
System Monitoring / Alerting (push) Has been cancelled
Security Scanning / Dependency Scanning (push) Has been cancelled
Security Scanning / Code Security Scanning (push) Has been cancelled
Security Scanning / Secrets Scanning (push) Has been cancelled
Security Scanning / Container Security Scanning (push) Has been cancelled
Security Scanning / Compliance Checking (push) Has been cancelled
Security Scanning / Security Dashboard (push) Has been cancelled
Security Scanning / Security Remediation (push) Has been cancelled
Some checks failed
System Monitoring / Health Checks (push) Has been cancelled
System Monitoring / Performance Monitoring (push) Has been cancelled
System Monitoring / Database Monitoring (push) Has been cancelled
System Monitoring / Cache Monitoring (push) Has been cancelled
System Monitoring / Log Monitoring (push) Has been cancelled
System Monitoring / Resource Monitoring (push) Has been cancelled
System Monitoring / Uptime Monitoring (push) Has been cancelled
System Monitoring / Backup Monitoring (push) Has been cancelled
System Monitoring / Security Monitoring (push) Has been cancelled
System Monitoring / Monitoring Dashboard (push) Has been cancelled
System Monitoring / Alerting (push) Has been cancelled
Security Scanning / Dependency Scanning (push) Has been cancelled
Security Scanning / Code Security Scanning (push) Has been cancelled
Security Scanning / Secrets Scanning (push) Has been cancelled
Security Scanning / Container Security Scanning (push) Has been cancelled
Security Scanning / Compliance Checking (push) Has been cancelled
Security Scanning / Security Dashboard (push) Has been cancelled
Security Scanning / Security Remediation (push) Has been cancelled
This commit is contained in:
926
backend/security/api_security.py
Normal file
926
backend/security/api_security.py
Normal file
@@ -0,0 +1,926 @@
|
||||
"""
|
||||
Enhanced API security with JWT, OAuth, and Malaysian-specific authentication.
|
||||
"""
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.cache import cache
|
||||
from django.http import HttpRequest, HttpResponse, JsonResponse
|
||||
from django.utils import timezone
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from django.views.decorators.http import require_http_methods
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import api_view, permission_classes
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework_simplejwt.tokens import RefreshToken, AccessToken
|
||||
from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView
|
||||
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
|
||||
from jose import jwt, JWTError
|
||||
from jose.exceptions import ExpiredSignatureError, JWTClaimsError
|
||||
import logging
|
||||
import hashlib
|
||||
import hmac
|
||||
from urllib.parse import urlparse
|
||||
import requests
|
||||
from requests.auth import HTTPBasicAuth
|
||||
|
||||
from .auth import SecurePasswordManager, SecureSessionManager
|
||||
from .middleware import SecurityLoggingMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
class MalaysianBusinessAuthenticator:
|
||||
"""
|
||||
Malaysian business authentication and authorization system.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.password_manager = SecurePasswordManager()
|
||||
self.session_manager = SecureSessionManager()
|
||||
self.logger = logging.getLogger('security.auth')
|
||||
|
||||
def authenticate_business(self, request: HttpRequest) -> Optional[Dict]:
|
||||
"""
|
||||
Authenticate Malaysian business with enhanced security.
|
||||
"""
|
||||
try:
|
||||
# Get credentials
|
||||
business_registration = request.data.get('business_registration')
|
||||
password = request.data.get('password')
|
||||
ssm_code = request.data.get('ssm_code')
|
||||
|
||||
# Validate required fields
|
||||
if not all([business_registration, password, ssm_code]):
|
||||
return None
|
||||
|
||||
# Verify SSM code format
|
||||
if not self._validate_ssm_code(ssm_code):
|
||||
self.logger.warning(f"Invalid SSM code format: {ssm_code}")
|
||||
return None
|
||||
|
||||
# Get user by business registration
|
||||
try:
|
||||
user = User.objects.get(
|
||||
business_registration__registration_number=business_registration,
|
||||
is_active=True
|
||||
)
|
||||
except User.DoesNotExist:
|
||||
self.logger.warning(f"Business not found: {business_registration}")
|
||||
return None
|
||||
|
||||
# Verify password
|
||||
if not self.password_manager.verify_password(password, user.password):
|
||||
self.logger.warning(f"Invalid password for business: {business_registration}")
|
||||
return None
|
||||
|
||||
# Verify SSM code
|
||||
if not self._verify_ssm_code(user, ssm_code):
|
||||
self.logger.warning(f"Invalid SSM code for business: {business_registration}")
|
||||
return None
|
||||
|
||||
# Check business status
|
||||
if not self._check_business_status(user):
|
||||
self.logger.warning(f"Business not active: {business_registration}")
|
||||
return None
|
||||
|
||||
# Update last login
|
||||
user.last_login = timezone.now()
|
||||
user.save()
|
||||
|
||||
return {
|
||||
'user': user,
|
||||
'business_registration': business_registration,
|
||||
'ssm_code': ssm_code,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Authentication error: {e}")
|
||||
return None
|
||||
|
||||
def _validate_ssm_code(self, ssm_code: str) -> bool:
|
||||
"""
|
||||
Validate SSM code format.
|
||||
"""
|
||||
# SSM code should be 6 digits
|
||||
if len(ssm_code) != 6:
|
||||
return False
|
||||
|
||||
# Should be numeric
|
||||
if not ssm_code.isdigit():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _verify_ssm_code(self, user: User, ssm_code: str) -> bool:
|
||||
"""
|
||||
Verify SSM code against stored hash.
|
||||
"""
|
||||
try:
|
||||
# Get stored SSM code hash
|
||||
stored_hash = getattr(user, 'ssm_code_hash', None)
|
||||
if not stored_hash:
|
||||
return False
|
||||
|
||||
# Verify hash
|
||||
return self.password_manager.verify_password(ssm_code, stored_hash)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"SSM code verification error: {e}")
|
||||
return False
|
||||
|
||||
def _check_business_status(self, user: User) -> bool:
|
||||
"""
|
||||
Check if business is active and compliant.
|
||||
"""
|
||||
try:
|
||||
# Check if user is active
|
||||
if not user.is_active:
|
||||
return False
|
||||
|
||||
# Check business registration status
|
||||
business_registration = getattr(user, 'business_registration', None)
|
||||
if not business_registration:
|
||||
return False
|
||||
|
||||
# Check if registration is active
|
||||
if not business_registration.is_active:
|
||||
return False
|
||||
|
||||
# Check compliance status
|
||||
if not business_registration.is_compliant:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Business status check error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class SecureJWTAuthentication:
|
||||
"""
|
||||
Enhanced JWT authentication with Malaysian security considerations.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.secret_key = settings.SECRET_KEY
|
||||
self.algorithm = 'HS256'
|
||||
self.access_token_lifetime = getattr(settings, 'SIMPLE_JWT', {}).get('ACCESS_TOKEN_LIFETIME', timedelta(minutes=15))
|
||||
self.refresh_token_lifetime = getattr(settings, 'SIMPLE_JWT', {}).get('REFRESH_TOKEN_LIFETIME', timedelta(days=7))
|
||||
self.logger = logging.getLogger('security.jwt')
|
||||
|
||||
def generate_token_pair(self, user: User) -> Dict[str, str]:
|
||||
"""
|
||||
Generate secure JWT token pair.
|
||||
"""
|
||||
try:
|
||||
# Generate access token
|
||||
access_token = self._generate_access_token(user)
|
||||
|
||||
# Generate refresh token
|
||||
refresh_token = self._generate_refresh_token(user)
|
||||
|
||||
return {
|
||||
'access_token': access_token,
|
||||
'refresh_token': refresh_token,
|
||||
'token_type': 'Bearer',
|
||||
'expires_in': int(self.access_token_lifetime.total_seconds()),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Token generation error: {e}")
|
||||
raise
|
||||
|
||||
def _generate_access_token(self, user: User) -> str:
|
||||
"""
|
||||
Generate access token with enhanced claims.
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Basic claims
|
||||
claims = {
|
||||
'user_id': user.id,
|
||||
'username': user.username,
|
||||
'email': user.email,
|
||||
'business_registration': getattr(user, 'business_registration', {}).get('registration_number', ''),
|
||||
'exp': now + self.access_token_lifetime,
|
||||
'iat': now,
|
||||
'jti': secrets.token_urlsafe(16),
|
||||
'type': 'access',
|
||||
}
|
||||
|
||||
# Add Malaysian-specific claims
|
||||
claims.update({
|
||||
'malaysian_business': True,
|
||||
'business_type': getattr(user, 'business_registration', {}).get('business_type', ''),
|
||||
'state': getattr(user, 'business_registration', {}).get('state', ''),
|
||||
'compliance_status': getattr(user, 'business_registration', {}).get('compliance_status', ''),
|
||||
})
|
||||
|
||||
# Add security claims
|
||||
claims.update({
|
||||
'device_fingerprint': self._get_device_fingerprint(),
|
||||
'ip_address': self._get_client_ip(),
|
||||
'session_id': self._generate_session_id(),
|
||||
})
|
||||
|
||||
return jwt.encode(claims, self.secret_key, algorithm=self.algorithm)
|
||||
|
||||
def _generate_refresh_token(self, user: User) -> str:
|
||||
"""
|
||||
Generate refresh token with enhanced claims.
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
|
||||
claims = {
|
||||
'user_id': user.id,
|
||||
'exp': now + self.refresh_token_lifetime,
|
||||
'iat': now,
|
||||
'jti': secrets.token_urlsafe(16),
|
||||
'type': 'refresh',
|
||||
}
|
||||
|
||||
return jwt.encode(claims, self.secret_key, algorithm=self.algorithm)
|
||||
|
||||
def verify_token(self, token: str, token_type: str = 'access') -> Optional[Dict]:
|
||||
"""
|
||||
Verify JWT token with enhanced validation.
|
||||
"""
|
||||
try:
|
||||
# Decode token
|
||||
claims = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
|
||||
# Validate token type
|
||||
if claims.get('type') != token_type:
|
||||
self.logger.warning(f"Invalid token type: {token_type}")
|
||||
return None
|
||||
|
||||
# Validate claims
|
||||
if not self._validate_claims(claims):
|
||||
return None
|
||||
|
||||
# Check if token is revoked
|
||||
if self._is_token_revoked(claims):
|
||||
self.logger.warning("Token revoked")
|
||||
return None
|
||||
|
||||
return claims
|
||||
|
||||
except ExpiredSignatureError:
|
||||
self.logger.warning("Token expired")
|
||||
return None
|
||||
except JWTClaimsError as e:
|
||||
self.logger.warning(f"Invalid claims: {e}")
|
||||
return None
|
||||
except JWTError as e:
|
||||
self.logger.warning(f"JWT error: {e}")
|
||||
return None
|
||||
|
||||
def _validate_claims(self, claims: Dict) -> bool:
|
||||
"""
|
||||
Validate token claims.
|
||||
"""
|
||||
# Check required claims
|
||||
required_claims = ['user_id', 'jti', 'type', 'iat', 'exp']
|
||||
for claim in required_claims:
|
||||
if claim not in claims:
|
||||
self.logger.warning(f"Missing required claim: {claim}")
|
||||
return False
|
||||
|
||||
# Validate user exists
|
||||
try:
|
||||
user = User.objects.get(id=claims['user_id'])
|
||||
if not user.is_active:
|
||||
self.logger.warning("User not active")
|
||||
return False
|
||||
except User.DoesNotExist:
|
||||
self.logger.warning("User not found")
|
||||
return False
|
||||
|
||||
# Validate device fingerprint if present
|
||||
if 'device_fingerprint' in claims:
|
||||
if claims['device_fingerprint'] != self._get_device_fingerprint():
|
||||
self.logger.warning("Device fingerprint mismatch")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _is_token_revoked(self, claims: Dict) -> bool:
|
||||
"""
|
||||
Check if token is revoked.
|
||||
"""
|
||||
try:
|
||||
jti = claims.get('jti')
|
||||
if not jti:
|
||||
return True
|
||||
|
||||
# Check cache for revoked tokens
|
||||
cache_key = f"revoked_token_{jti}"
|
||||
return cache.get(cache_key) is not None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Token revocation check error: {e}")
|
||||
return True
|
||||
|
||||
def revoke_token(self, claims: Dict):
|
||||
"""
|
||||
Revoke token by adding to revoked list.
|
||||
"""
|
||||
try:
|
||||
jti = claims.get('jti')
|
||||
if not jti:
|
||||
return
|
||||
|
||||
# Add to revoked tokens cache
|
||||
cache_key = f"revoked_token_{jti}"
|
||||
exp = claims.get('exp')
|
||||
if exp:
|
||||
# Cache until token expires
|
||||
ttl = max(0, exp - int(time.time()))
|
||||
cache.set(cache_key, True, timeout=ttl)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Token revocation error: {e}")
|
||||
|
||||
def _get_device_fingerprint(self) -> str:
|
||||
"""
|
||||
Generate device fingerprint.
|
||||
"""
|
||||
# Simple device fingerprint based on user agent and IP
|
||||
import hashlib
|
||||
|
||||
user_agent = "unknown" # Would get from request
|
||||
ip_address = self._get_client_ip()
|
||||
|
||||
fingerprint = f"{user_agent}:{ip_address}"
|
||||
return hashlib.sha256(fingerprint.encode()).hexdigest()
|
||||
|
||||
def _get_client_ip(self) -> str:
|
||||
"""
|
||||
Get client IP address.
|
||||
"""
|
||||
# This would get from request headers
|
||||
return "127.0.0.1" # Placeholder
|
||||
|
||||
def _generate_session_id(self) -> str:
|
||||
"""
|
||||
Generate session ID.
|
||||
"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
class OAuth2Provider:
|
||||
"""
|
||||
OAuth2 provider integration for Malaysian services.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.providers = getattr(settings, 'OAUTH2_PROVIDERS', {})
|
||||
self.logger = logging.getLogger('security.oauth2')
|
||||
|
||||
def authenticate_with_oauth2(self, provider: str, code: str, redirect_uri: str) -> Optional[Dict]:
|
||||
"""
|
||||
Authenticate using OAuth2 provider.
|
||||
"""
|
||||
try:
|
||||
if provider not in self.providers:
|
||||
self.logger.warning(f"Unknown OAuth2 provider: {provider}")
|
||||
return None
|
||||
|
||||
provider_config = self.providers[provider]
|
||||
|
||||
# Exchange code for access token
|
||||
token_response = self._exchange_code_for_token(
|
||||
provider_config, code, redirect_uri
|
||||
)
|
||||
|
||||
if not token_response:
|
||||
return None
|
||||
|
||||
# Get user info
|
||||
user_info = self._get_user_info(provider_config, token_response['access_token'])
|
||||
|
||||
if not user_info:
|
||||
return None
|
||||
|
||||
# Find or create user
|
||||
user = self._find_or_create_user(provider, user_info)
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
||||
return {
|
||||
'user': user,
|
||||
'provider': provider,
|
||||
'access_token': token_response['access_token'],
|
||||
'refresh_token': token_response.get('refresh_token'),
|
||||
'expires_in': token_response.get('expires_in'),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"OAuth2 authentication error: {e}")
|
||||
return None
|
||||
|
||||
def _exchange_code_for_token(self, provider_config: Dict, code: str, redirect_uri: str) -> Optional[Dict]:
|
||||
"""
|
||||
Exchange authorization code for access token.
|
||||
"""
|
||||
try:
|
||||
token_url = provider_config['token_url']
|
||||
client_id = provider_config['client_id']
|
||||
client_secret = provider_config['client_secret']
|
||||
|
||||
data = {
|
||||
'grant_type': 'authorization_code',
|
||||
'code': code,
|
||||
'redirect_uri': redirect_uri,
|
||||
'client_id': client_id,
|
||||
'client_secret': client_secret,
|
||||
}
|
||||
|
||||
response = requests.post(token_url, data=data, timeout=10)
|
||||
|
||||
if response.status_code != 200:
|
||||
self.logger.warning(f"Token exchange failed: {response.status_code}")
|
||||
return None
|
||||
|
||||
return response.json()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Token exchange error: {e}")
|
||||
return None
|
||||
|
||||
def _get_user_info(self, provider_config: Dict, access_token: str) -> Optional[Dict]:
|
||||
"""
|
||||
Get user info from OAuth2 provider.
|
||||
"""
|
||||
try:
|
||||
user_info_url = provider_config['user_info_url']
|
||||
headers = {'Authorization': f'Bearer {access_token}'}
|
||||
|
||||
response = requests.get(user_info_url, headers=headers, timeout=10)
|
||||
|
||||
if response.status_code != 200:
|
||||
self.logger.warning(f"User info request failed: {response.status_code}")
|
||||
return None
|
||||
|
||||
return response.json()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"User info request error: {e}")
|
||||
return None
|
||||
|
||||
def _find_or_create_user(self, provider: str, user_info: Dict) -> Optional[User]:
|
||||
"""
|
||||
Find or create user from OAuth2 provider info.
|
||||
"""
|
||||
try:
|
||||
# Try to find user by email
|
||||
email = user_info.get('email')
|
||||
if not email:
|
||||
self.logger.warning("No email in user info")
|
||||
return None
|
||||
|
||||
try:
|
||||
user = User.objects.get(email=email)
|
||||
# Update OAuth2 info
|
||||
if not hasattr(user, 'oauth2_providers'):
|
||||
user.oauth2_providers = {}
|
||||
user.oauth2_providers[provider] = user_info
|
||||
user.save()
|
||||
return user
|
||||
except User.DoesNotExist:
|
||||
# Create new user
|
||||
user = User.objects.create(
|
||||
username=user_info.get('email', ''),
|
||||
email=email,
|
||||
first_name=user_info.get('given_name', ''),
|
||||
last_name=user_info.get('family_name', ''),
|
||||
is_active=True,
|
||||
oauth2_providers={provider: user_info}
|
||||
)
|
||||
return user
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"User creation error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class APIRateLimiter:
|
||||
"""
|
||||
Enhanced API rate limiting with Malaysian considerations.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger('security.ratelimit')
|
||||
|
||||
# Rate limits
|
||||
self.rate_limits = {
|
||||
'default': {'requests': 100, 'window': 60}, # 100 requests per minute
|
||||
'login': {'requests': 5, 'window': 60}, # 5 login attempts per minute
|
||||
'api': {'requests': 1000, 'window': 60}, # 1000 API requests per minute
|
||||
'upload': {'requests': 10, 'window': 60}, # 10 uploads per minute
|
||||
'export': {'requests': 20, 'window': 60}, # 20 exports per minute
|
||||
}
|
||||
|
||||
# Malaysian business limits
|
||||
self.malaysian_business_limits = {
|
||||
'sst_calculation': {'requests': 50, 'window': 60},
|
||||
'ic_validation': {'requests': 30, 'window': 60},
|
||||
'postcode_lookup': {'requests': 100, 'window': 60},
|
||||
}
|
||||
|
||||
def check_rate_limit(self, key: str, endpoint: str = 'default') -> bool:
|
||||
"""
|
||||
Check if request is within rate limit.
|
||||
"""
|
||||
try:
|
||||
# Get limit for endpoint
|
||||
if endpoint in self.malaysian_business_limits:
|
||||
limit = self.malaysian_business_limits[endpoint]
|
||||
elif endpoint in self.rate_limits:
|
||||
limit = self.rate_limits[endpoint]
|
||||
else:
|
||||
limit = self.rate_limits['default']
|
||||
|
||||
# Check cache
|
||||
cache_key = f"rate_limit_{key}_{endpoint}"
|
||||
current = cache.get(cache_key, 0)
|
||||
|
||||
if current >= limit['requests']:
|
||||
self.logger.warning(f"Rate limit exceeded for {key} on {endpoint}")
|
||||
return False
|
||||
|
||||
# Increment counter
|
||||
cache.set(cache_key, current + 1, timeout=limit['window'])
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Rate limit check error: {e}")
|
||||
return True # Allow on error
|
||||
|
||||
def get_rate_limit_info(self, key: str, endpoint: str = 'default') -> Dict:
|
||||
"""
|
||||
Get current rate limit information.
|
||||
"""
|
||||
try:
|
||||
# Get limit for endpoint
|
||||
if endpoint in self.malaysian_business_limits:
|
||||
limit = self.malaysian_business_limits[endpoint]
|
||||
elif endpoint in self.rate_limits:
|
||||
limit = self.rate_limits[endpoint]
|
||||
else:
|
||||
limit = self.rate_limits['default']
|
||||
|
||||
# Get current count
|
||||
cache_key = f"rate_limit_{key}_{endpoint}"
|
||||
current = cache.get(cache_key, 0)
|
||||
|
||||
ttl = cache.ttl(cache_key)
|
||||
if ttl < 0:
|
||||
ttl = limit['window']
|
||||
|
||||
return {
|
||||
'limit': limit['requests'],
|
||||
'remaining': max(0, limit['requests'] - current),
|
||||
'reset': ttl,
|
||||
'window': limit['window'],
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Rate limit info error: {e}")
|
||||
return {'limit': 0, 'remaining': 0, 'reset': 0, 'window': 0}
|
||||
|
||||
|
||||
# Authentication Views
|
||||
class MalaysianBusinessTokenView(TokenObtainPairView):
|
||||
"""
|
||||
Custom token view for Malaysian business authentication.
|
||||
"""
|
||||
|
||||
serializer_class = TokenObtainPairSerializer
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
try:
|
||||
# Initialize authenticator
|
||||
authenticator = MalaysianBusinessAuthenticator()
|
||||
|
||||
# Authenticate business
|
||||
auth_result = authenticator.authenticate_business(request)
|
||||
|
||||
if not auth_result:
|
||||
return Response(
|
||||
{'error': 'Invalid credentials'},
|
||||
status=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
# Initialize JWT authentication
|
||||
jwt_auth = SecureJWTAuthentication()
|
||||
|
||||
# Generate tokens
|
||||
tokens = jwt_auth.generate_token_pair(auth_result['user'])
|
||||
|
||||
# Log successful authentication
|
||||
logger.info(f"Successful authentication for business: {auth_result['business_registration']}")
|
||||
|
||||
return Response(tokens, status=status.HTTP_200_OK)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {e}")
|
||||
return Response(
|
||||
{'error': 'Authentication failed'},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
@api_view(['POST'])
|
||||
@permission_classes([IsAuthenticated])
|
||||
def refresh_token_view(request):
|
||||
"""
|
||||
Refresh JWT token.
|
||||
"""
|
||||
try:
|
||||
refresh_token = request.data.get('refresh_token')
|
||||
if not refresh_token:
|
||||
return Response(
|
||||
{'error': 'Refresh token required'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# Initialize JWT authentication
|
||||
jwt_auth = SecureJWTAuthentication()
|
||||
|
||||
# Verify refresh token
|
||||
claims = jwt_auth.verify_token(refresh_token, 'refresh')
|
||||
if not claims:
|
||||
return Response(
|
||||
{'error': 'Invalid refresh token'},
|
||||
status=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
# Get user
|
||||
try:
|
||||
user = User.objects.get(id=claims['user_id'])
|
||||
except User.DoesNotExist:
|
||||
return Response(
|
||||
{'error': 'User not found'},
|
||||
status=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
# Generate new token pair
|
||||
new_tokens = jwt_auth.generate_token_pair(user)
|
||||
|
||||
# Revoke old refresh token
|
||||
jwt_auth.revoke_token(claims)
|
||||
|
||||
return Response(new_tokens, status=status.HTTP_200_OK)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token refresh error: {e}")
|
||||
return Response(
|
||||
{'error': 'Token refresh failed'},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
@api_view(['POST'])
|
||||
@permission_classes([IsAuthenticated])
|
||||
def revoke_token_view(request):
|
||||
"""
|
||||
Revoke JWT token.
|
||||
"""
|
||||
try:
|
||||
token = request.data.get('token')
|
||||
if not token:
|
||||
return Response(
|
||||
{'error': 'Token required'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# Initialize JWT authentication
|
||||
jwt_auth = SecureJWTAuthentication()
|
||||
|
||||
# Verify token
|
||||
claims = jwt_auth.verify_token(token)
|
||||
if not claims:
|
||||
return Response(
|
||||
{'error': 'Invalid token'},
|
||||
status=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
# Revoke token
|
||||
jwt_auth.revoke_token(claims)
|
||||
|
||||
return Response({'message': 'Token revoked'}, status=status.HTTP_200_OK)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token revocation error: {e}")
|
||||
return Response(
|
||||
{'error': 'Token revocation failed'},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
@api_view(['POST'])
|
||||
def oauth2_login_view(request):
|
||||
"""
|
||||
OAuth2 login view.
|
||||
"""
|
||||
try:
|
||||
provider = request.data.get('provider')
|
||||
code = request.data.get('code')
|
||||
redirect_uri = request.data.get('redirect_uri')
|
||||
|
||||
if not all([provider, code, redirect_uri]):
|
||||
return Response(
|
||||
{'error': 'Missing required fields'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# Initialize OAuth2 provider
|
||||
oauth2_provider = OAuth2Provider()
|
||||
|
||||
# Authenticate with OAuth2
|
||||
auth_result = oauth2_provider.authenticate_with_oauth2(provider, code, redirect_uri)
|
||||
|
||||
if not auth_result:
|
||||
return Response(
|
||||
{'error': 'OAuth2 authentication failed'},
|
||||
status=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
# Initialize JWT authentication
|
||||
jwt_auth = SecureJWTAuthentication()
|
||||
|
||||
# Generate tokens
|
||||
tokens = jwt_auth.generate_token_pair(auth_result['user'])
|
||||
|
||||
# Log successful authentication
|
||||
logger.info(f"Successful OAuth2 authentication for user: {auth_result['user'].username}")
|
||||
|
||||
return Response(tokens, status=status.HTTP_200_OK)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth2 login error: {e}")
|
||||
return Response(
|
||||
{'error': 'OAuth2 authentication failed'},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
# Middleware for API security
|
||||
class APISecurityMiddleware:
|
||||
"""
|
||||
Middleware for API security with rate limiting and authentication.
|
||||
"""
|
||||
|
||||
def __init__(self, get_response):
|
||||
self.get_response = get_response
|
||||
self.rate_limiter = APIRateLimiter()
|
||||
self.jwt_auth = SecureJWTAuthentication()
|
||||
self.logger = logging.getLogger('security.api')
|
||||
|
||||
def __call__(self, request):
|
||||
# Skip for certain endpoints
|
||||
if self._should_skip(request):
|
||||
return self.get_response(request)
|
||||
|
||||
# Check rate limit
|
||||
if not self._check_rate_limit(request):
|
||||
return JsonResponse(
|
||||
{'error': 'Rate limit exceeded'},
|
||||
status=status.HTTP_429_TOO_MANY_REQUESTS
|
||||
)
|
||||
|
||||
# Check authentication
|
||||
if self._requires_auth(request):
|
||||
auth_result = self._check_authentication(request)
|
||||
if not auth_result:
|
||||
return JsonResponse(
|
||||
{'error': 'Authentication required'},
|
||||
status=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
# Add user to request
|
||||
request.user = auth_result['user']
|
||||
request.auth_claims = auth_result['claims']
|
||||
|
||||
return self.get_response(request)
|
||||
|
||||
def _should_skip(self, request) -> bool:
|
||||
"""
|
||||
Determine if security checks should be skipped.
|
||||
"""
|
||||
skip_paths = [
|
||||
'/health/',
|
||||
'/metrics/',
|
||||
'/static/',
|
||||
'/media/',
|
||||
'/api/v1/docs/',
|
||||
'/api/v1/oauth2/',
|
||||
]
|
||||
|
||||
return any(request.path.startswith(path) for path in skip_paths)
|
||||
|
||||
def _check_rate_limit(self, request) -> bool:
|
||||
"""
|
||||
Check rate limit for the request.
|
||||
"""
|
||||
try:
|
||||
# Get rate limit key
|
||||
if hasattr(request, 'user') and request.user.is_authenticated:
|
||||
key = f"user_{request.user.id}"
|
||||
else:
|
||||
key = f"ip_{self._get_client_ip(request)}"
|
||||
|
||||
# Determine endpoint type
|
||||
endpoint = 'default'
|
||||
if 'login' in request.path:
|
||||
endpoint = 'login'
|
||||
elif 'api' in request.path:
|
||||
endpoint = 'api'
|
||||
elif 'upload' in request.path:
|
||||
endpoint = 'upload'
|
||||
elif 'export' in request.path:
|
||||
endpoint = 'export'
|
||||
elif 'sst' in request.path:
|
||||
endpoint = 'sst_calculation'
|
||||
elif 'ic' in request.path:
|
||||
endpoint = 'ic_validation'
|
||||
elif 'postcode' in request.path:
|
||||
endpoint = 'postcode_lookup'
|
||||
|
||||
return self.rate_limiter.check_rate_limit(key, endpoint)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Rate limit check error: {e}")
|
||||
return True # Allow on error
|
||||
|
||||
def _requires_auth(self, request) -> bool:
|
||||
"""
|
||||
Check if request requires authentication.
|
||||
"""
|
||||
auth_paths = [
|
||||
'/api/v1/',
|
||||
'/business/',
|
||||
'/admin/',
|
||||
]
|
||||
|
||||
return any(request.path.startswith(path) for path in auth_paths)
|
||||
|
||||
def _check_authentication(self, request) -> Optional[Dict]:
|
||||
"""
|
||||
Check request authentication.
|
||||
"""
|
||||
try:
|
||||
# Get Authorization header
|
||||
auth_header = request.META.get('HTTP_AUTHORIZATION')
|
||||
if not auth_header:
|
||||
return None
|
||||
|
||||
# Parse Bearer token
|
||||
if not auth_header.startswith('Bearer '):
|
||||
return None
|
||||
|
||||
token = auth_header[7:]
|
||||
|
||||
# Verify token
|
||||
claims = self.jwt_auth.verify_token(token)
|
||||
if not claims:
|
||||
return None
|
||||
|
||||
# Get user
|
||||
try:
|
||||
user = User.objects.get(id=claims['user_id'])
|
||||
except User.DoesNotExist:
|
||||
return None
|
||||
|
||||
return {
|
||||
'user': user,
|
||||
'claims': claims,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Authentication check error: {e}")
|
||||
return None
|
||||
|
||||
def _get_client_ip(self, request) -> str:
|
||||
"""
|
||||
Get client IP address from request.
|
||||
"""
|
||||
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
|
||||
if x_forwarded_for:
|
||||
ip = x_forwarded_for.split(',')[0]
|
||||
else:
|
||||
ip = request.META.get('REMOTE_ADDR')
|
||||
|
||||
return ip
|
||||
560
backend/security/auth.py
Normal file
560
backend/security/auth.py
Normal file
@@ -0,0 +1,560 @@
|
||||
"""
|
||||
Enhanced authentication and authorization system.
|
||||
Implements Malaysian business requirements and security best practices.
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from django.contrib.auth import get_user_model, authenticate, login, logout
|
||||
from django.contrib.auth.password_validation import validate_password
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.cache import cache
|
||||
from django.conf import settings
|
||||
from django.utils import timezone
|
||||
from django.utils.crypto import get_random_string, salted_hmac
|
||||
from django.contrib.auth.signals import user_logged_in, user_logged_out, user_login_failed
|
||||
from django.dispatch import receiver
|
||||
from rest_framework_simplejwt.tokens import RefreshToken
|
||||
from rest_framework_simplejwt.authentication import JWTAuthentication
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
import redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
User = get_user_model()
|
||||
|
||||
class MalaysianBusinessAuth:
|
||||
"""Malaysian business-specific authentication."""
|
||||
|
||||
def __init__(self):
|
||||
self.redis_client = self._get_redis_client()
|
||||
self.business_registration_pattern = re.compile(r'^\d{12}$')
|
||||
self.ic_pattern = re.compile(r'^\d{12}$|^(\d{6}-\d{2}-\d{4})$')
|
||||
|
||||
def validate_business_registration(self, registration_number: str) -> bool:
|
||||
"""Validate Malaysian business registration number."""
|
||||
# Remove spaces and dashes
|
||||
clean_number = re.sub(r'[\s-]', '', registration_number)
|
||||
|
||||
# Check format (12 digits for most Malaysian businesses)
|
||||
if not self.business_registration_pattern.match(clean_number):
|
||||
return False
|
||||
|
||||
# Check against known prefixes (optional enhancement)
|
||||
valid_prefixes = [
|
||||
'1', # Sole proprietorship
|
||||
'2', # Partnership
|
||||
'3', # Private limited company
|
||||
'4', # Public limited company
|
||||
]
|
||||
|
||||
return clean_number[0] in valid_prefixes
|
||||
|
||||
def validate_ic_number(self, ic_number: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Validate Malaysian IC number and return age if valid."""
|
||||
# Clean IC number
|
||||
clean_ic = re.sub(r'[\s-]', '', ic_number)
|
||||
|
||||
# Check format
|
||||
if not self.ic_pattern.match(clean_ic):
|
||||
return False, None
|
||||
|
||||
# Extract birth date (simplified validation)
|
||||
try:
|
||||
if len(clean_ic) == 12: # New format without dashes
|
||||
birth_year = int(clean_ic[:2])
|
||||
birth_month = int(clean_ic[2:4])
|
||||
birth_day = int(clean_ic[4:6])
|
||||
|
||||
# Determine century (rough estimation)
|
||||
if birth_year <= 30: # Assume 2000s
|
||||
full_year = 2000 + birth_year
|
||||
else: # Assume 1900s
|
||||
full_year = 1900 + birth_year
|
||||
|
||||
# Validate date
|
||||
birth_date = datetime(full_year, birth_month, birth_day)
|
||||
age = (datetime.now() - birth_date).days // 365
|
||||
|
||||
return True, age
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return False, None
|
||||
|
||||
return False, None
|
||||
|
||||
def get_business_tier(self, user: User) -> str:
|
||||
"""Get business tier based on user profile."""
|
||||
if hasattr(user, 'business_profile'):
|
||||
profile = user.business_profile
|
||||
annual_revenue = getattr(profile, 'annual_revenue', 0)
|
||||
employee_count = getattr(profile, 'employee_count', 0)
|
||||
|
||||
if annual_revenue > 5000000 or employee_count > 100:
|
||||
return 'enterprise'
|
||||
elif annual_revenue > 1000000 or employee_count > 20:
|
||||
return 'professional'
|
||||
else:
|
||||
return 'basic'
|
||||
|
||||
return 'basic'
|
||||
|
||||
class SecurePasswordManager:
|
||||
"""Enhanced password security manager."""
|
||||
|
||||
def __init__(self):
|
||||
self.min_length = getattr(settings, 'PASSWORD_MIN_LENGTH', 12)
|
||||
self.max_age_days = getattr(settings, 'PASSWORD_MAX_AGE_DAYS', 90)
|
||||
self.history_count = getattr(settings, 'PASSWORD_HISTORY_COUNT', 5)
|
||||
self.lockout_threshold = getattr(settings, 'PASSWORD_LOCKOUT_THRESHOLD', 5)
|
||||
self.lockout_duration = getattr(settings, 'PASSWORD_LOCKOUT_DURATION', 15) # minutes
|
||||
|
||||
def validate_password_strength(self, password: str, user: User) -> List[str]:
|
||||
"""Validate password strength with Malaysian considerations."""
|
||||
errors = []
|
||||
|
||||
# Basic Django validation
|
||||
try:
|
||||
validate_password(password, user)
|
||||
except ValidationError as e:
|
||||
errors.extend(e.messages)
|
||||
|
||||
# Additional strength requirements
|
||||
if len(password) < self.min_length:
|
||||
errors.append(f'Password must be at least {self.min_length} characters long')
|
||||
|
||||
# Check for common Malaysian passwords
|
||||
malaysian_common_passwords = [
|
||||
'malaysia', 'kuala', 'lumpur', 'putrajaya', 'johor',
|
||||
'selangor', 'penang', 'sabah', 'sarawak', 'melaka',
|
||||
'123456', 'password', 'qwerty', 'abc123'
|
||||
]
|
||||
|
||||
if password.lower() in malaysian_common_passwords:
|
||||
errors.append('Password is too common')
|
||||
|
||||
# Check for personal information
|
||||
if hasattr(user, 'profile'):
|
||||
personal_info = [
|
||||
getattr(user.profile, 'first_name', '').lower(),
|
||||
getattr(user.profile, 'last_name', '').lower(),
|
||||
getattr(user.profile, 'ic_number', ''),
|
||||
getattr(user.profile, 'business_name', '').lower(),
|
||||
user.username.lower(),
|
||||
user.email.split('@')[0].lower()
|
||||
]
|
||||
|
||||
for info in personal_info:
|
||||
if info and info.lower() in password.lower():
|
||||
errors.append('Password contains personal information')
|
||||
|
||||
return errors
|
||||
|
||||
def check_password_history(self, user: User, new_password: str) -> bool:
|
||||
"""Check if password has been used before."""
|
||||
if not hasattr(user, 'password_history'):
|
||||
return True
|
||||
|
||||
history = json.loads(user.password_history) if user.password_history else []
|
||||
|
||||
for old_password_hash in history[-self.history_count:]:
|
||||
if user.check_password(new_password, old_password_hash):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def record_password_change(self, user: User, new_password: str):
|
||||
"""Record password change in history."""
|
||||
if not hasattr(user, 'password_history'):
|
||||
user.password_history = '[]'
|
||||
|
||||
history = json.loads(user.password_history)
|
||||
history.append(user.password)
|
||||
|
||||
# Keep only recent history
|
||||
if len(history) > self.history_count:
|
||||
history = history[-self.history_count:]
|
||||
|
||||
user.password_history = json.dumps(history)
|
||||
user.password_change_date = timezone.now()
|
||||
user.save()
|
||||
|
||||
def check_password_expiry(self, user: User) -> bool:
|
||||
"""Check if password has expired."""
|
||||
if not hasattr(user, 'password_change_date'):
|
||||
return True
|
||||
|
||||
expiry_date = user.password_change_date + timedelta(days=self.max_age_days)
|
||||
return timezone.now() > expiry_date
|
||||
|
||||
def is_account_locked(self, user: User) -> bool:
|
||||
"""Check if account is locked due to failed attempts."""
|
||||
if not hasattr(user, 'failed_login_attempts'):
|
||||
return False
|
||||
|
||||
return user.failed_login_attempts >= self.lockout_threshold
|
||||
|
||||
def record_failed_login(self, user: User):
|
||||
"""Record failed login attempt."""
|
||||
user.failed_login_attempts = (user.failed_login_attempts or 0) + 1
|
||||
|
||||
if user.failed_login_attempts >= self.lockout_threshold:
|
||||
user.locked_until = timezone.now() + timedelta(minutes=self.lockout_duration)
|
||||
|
||||
user.save()
|
||||
|
||||
def reset_failed_logins(self, user: User):
|
||||
"""Reset failed login attempts."""
|
||||
user.failed_login_attempts = 0
|
||||
user.locked_until = None
|
||||
user.save()
|
||||
|
||||
class SecureSessionManager:
|
||||
"""Enhanced session security."""
|
||||
|
||||
def __init__(self):
|
||||
self.redis_client = self._get_redis_client()
|
||||
self.session_timeout = getattr(settings, 'SESSION_TIMEOUT', 3600) # 1 hour
|
||||
self.concurrent_sessions = getattr(settings, 'MAX_CONCURRENT_SESSIONS', 3)
|
||||
|
||||
def _get_redis_client(self):
|
||||
"""Get Redis client for session management."""
|
||||
try:
|
||||
return redis.from_url(settings.REDIS_URL)
|
||||
except Exception:
|
||||
logger.warning("Redis not available for session management")
|
||||
return None
|
||||
|
||||
def create_secure_session(self, user: User, request) -> str:
|
||||
"""Create secure session with device fingerprinting."""
|
||||
session_id = secrets.token_urlsafe(32)
|
||||
|
||||
# Get device fingerprint
|
||||
device_info = self._get_device_fingerprint(request)
|
||||
|
||||
# Store session data
|
||||
session_data = {
|
||||
'user_id': user.id,
|
||||
'created_at': timezone.now().isoformat(),
|
||||
'last_activity': timezone.now().isoformat(),
|
||||
'device_info': device_info,
|
||||
'ip_address': self._get_client_ip(request),
|
||||
'user_agent': request.META.get('HTTP_USER_AGENT', ''),
|
||||
'is_active': True
|
||||
}
|
||||
|
||||
if self.redis_client:
|
||||
self.redis_client.setex(
|
||||
f"session:{session_id}",
|
||||
self.session_timeout,
|
||||
json.dumps(session_data)
|
||||
)
|
||||
|
||||
return session_id
|
||||
|
||||
def validate_session(self, session_id: str, request) -> Optional[User]:
|
||||
"""Validate session and return user."""
|
||||
if not self.redis_client:
|
||||
return None
|
||||
|
||||
session_data = self.redis_client.get(f"session:{session_id}")
|
||||
if not session_data:
|
||||
return None
|
||||
|
||||
try:
|
||||
session_info = json.loads(session_data)
|
||||
|
||||
# Check if session is active
|
||||
if not session_info.get('is_active', True):
|
||||
return None
|
||||
|
||||
# Update last activity
|
||||
session_info['last_activity'] = timezone.now().isoformat()
|
||||
self.redis_client.setex(
|
||||
f"session:{session_id}",
|
||||
self.session_timeout,
|
||||
json.dumps(session_info)
|
||||
)
|
||||
|
||||
# Get user
|
||||
try:
|
||||
user = User.objects.get(id=session_info['user_id'], is_active=True)
|
||||
return user
|
||||
except User.DoesNotExist:
|
||||
return None
|
||||
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
return None
|
||||
|
||||
def revoke_session(self, session_id: str):
|
||||
"""Revoke specific session."""
|
||||
if self.redis_client:
|
||||
self.redis_client.delete(f"session:{session_id}")
|
||||
|
||||
def revoke_all_sessions(self, user: User):
|
||||
"""Revoke all sessions for user."""
|
||||
if not self.redis_client:
|
||||
return
|
||||
|
||||
# Find all user sessions (simplified - in production use session storage)
|
||||
session_keys = self.redis_client.keys("session:*")
|
||||
for key in session_keys:
|
||||
try:
|
||||
session_data = self.redis_client.get(key)
|
||||
if session_data:
|
||||
session_info = json.loads(session_data)
|
||||
if session_info.get('user_id') == user.id:
|
||||
self.redis_client.delete(key)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
|
||||
def get_active_sessions(self, user: User) -> List[Dict[str, Any]]:
|
||||
"""Get active sessions for user."""
|
||||
if not self.redis_client:
|
||||
return []
|
||||
|
||||
sessions = []
|
||||
session_keys = self.redis_client.keys("session:*")
|
||||
|
||||
for key in session_keys:
|
||||
try:
|
||||
session_data = self.redis_client.get(key)
|
||||
if session_data:
|
||||
session_info = json.loads(session_data)
|
||||
if session_info.get('user_id') == user.id and session_info.get('is_active', True):
|
||||
sessions.append({
|
||||
'session_id': key.decode('utf-8').split(':')[1],
|
||||
'created_at': session_info['created_at'],
|
||||
'last_activity': session_info['last_activity'],
|
||||
'device_info': session_info['device_info'],
|
||||
'ip_address': session_info['ip_address']
|
||||
})
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
|
||||
return sessions
|
||||
|
||||
def _get_device_fingerprint(self, request) -> Dict[str, str]:
|
||||
"""Get device fingerprint for session."""
|
||||
user_agent = request.META.get('HTTP_USER_AGENT', '')
|
||||
accept_language = request.META.get('HTTP_ACCEPT_LANGUAGE', '')
|
||||
|
||||
return {
|
||||
'user_agent': user_agent,
|
||||
'accept_language': accept_language,
|
||||
'browser': self._get_browser_info(user_agent),
|
||||
'os': self._get_os_info(user_agent),
|
||||
}
|
||||
|
||||
def _get_browser_info(self, user_agent: str) -> str:
|
||||
"""Extract browser information from user agent."""
|
||||
if 'Chrome' in user_agent:
|
||||
return 'Chrome'
|
||||
elif 'Firefox' in user_agent:
|
||||
return 'Firefox'
|
||||
elif 'Safari' in user_agent and 'Chrome' not in user_agent:
|
||||
return 'Safari'
|
||||
elif 'Edge' in user_agent:
|
||||
return 'Edge'
|
||||
else:
|
||||
return 'Unknown'
|
||||
|
||||
def _get_os_info(self, user_agent: str) -> str:
|
||||
"""Extract OS information from user agent."""
|
||||
if 'Windows' in user_agent:
|
||||
return 'Windows'
|
||||
elif 'Mac' in user_agent:
|
||||
return 'macOS'
|
||||
elif 'Linux' in user_agent:
|
||||
return 'Linux'
|
||||
elif 'Android' in user_agent:
|
||||
return 'Android'
|
||||
elif 'iOS' in user_agent:
|
||||
return 'iOS'
|
||||
else:
|
||||
return 'Unknown'
|
||||
|
||||
def _get_client_ip(self, request) -> str:
|
||||
"""Get client IP address."""
|
||||
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
|
||||
if x_forwarded_for:
|
||||
return x_forwarded_for.split(',')[0].strip()
|
||||
return request.META.get('REMOTE_ADDR', 'unknown')
|
||||
|
||||
class SecureJWTAuthentication(JWTAuthentication):
|
||||
"""Enhanced JWT authentication with Malaysian compliance."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.password_manager = SecurePasswordManager()
|
||||
self.session_manager = SecureSessionManager()
|
||||
self.malaysian_auth = MalaysianBusinessAuth()
|
||||
|
||||
def authenticate(self, request):
|
||||
"""Authenticate user with enhanced security checks."""
|
||||
try:
|
||||
# Get JWT token
|
||||
auth_header = request.headers.get('Authorization')
|
||||
if not auth_header or not auth_header.startswith('Bearer '):
|
||||
return None
|
||||
|
||||
token = auth_header.split(' ')[1]
|
||||
|
||||
# Validate token
|
||||
validated_token = self.get_validated_token(token)
|
||||
user_id = validated_token['user_id']
|
||||
|
||||
# Get user
|
||||
user = User.objects.get(id=user_id)
|
||||
|
||||
# Enhanced security checks
|
||||
if not self._perform_security_checks(user, request, validated_token):
|
||||
raise AuthenticationFailed('Security validation failed')
|
||||
|
||||
return user
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {e}")
|
||||
raise AuthenticationFailed('Invalid token')
|
||||
|
||||
def _perform_security_checks(self, user: User, request, token) -> bool:
|
||||
"""Perform enhanced security checks."""
|
||||
# Check if user is active
|
||||
if not user.is_active:
|
||||
return False
|
||||
|
||||
# Check if account is locked
|
||||
if self.password_manager.is_account_locked(user):
|
||||
return False
|
||||
|
||||
# Check password expiry
|
||||
if self.password_manager.check_password_expiry(user):
|
||||
return False
|
||||
|
||||
# Check token claims
|
||||
if not self._validate_token_claims(token, user, request):
|
||||
return False
|
||||
|
||||
# Check session concurrency
|
||||
if not self._check_session_concurrency(user, token):
|
||||
return False
|
||||
|
||||
# Malaysian business validation
|
||||
if hasattr(user, 'business_profile'):
|
||||
if not self.malaysian_auth.validate_business_registration(
|
||||
user.business_profile.registration_number
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _validate_token_claims(self, token, user, request) -> bool:
|
||||
"""Validate JWT token claims."""
|
||||
# Check token expiration
|
||||
if timezone.now() > datetime.fromtimestamp(token['exp'], tz=timezone.utc):
|
||||
return False
|
||||
|
||||
# Check issuer
|
||||
if token.get('iss') != getattr(settings, 'JWT_ISSUER', 'malaysian-sme-platform'):
|
||||
return False
|
||||
|
||||
# Check audience
|
||||
if token.get('aud') != getattr(settings, 'JWT_AUDIENCE', 'malaysian-sme-users'):
|
||||
return False
|
||||
|
||||
# Check IP address binding (if enabled)
|
||||
if getattr(settings, 'JWT_BIND_TO_IP', False):
|
||||
token_ip = token.get('ip_address')
|
||||
current_ip = self._get_client_ip(request)
|
||||
if token_ip and token_ip != current_ip:
|
||||
return False
|
||||
|
||||
# Check device binding (if enabled)
|
||||
if getattr(settings, 'JWT_BIND_TO_DEVICE', False):
|
||||
token_device = token.get('device_fingerprint')
|
||||
current_device = self._get_device_fingerprint(request)
|
||||
if token_device and token_device != current_device:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _check_session_concurrency(self, user, token) -> bool:
|
||||
"""Check session concurrency limits."""
|
||||
if not getattr(settings, 'ENFORCE_SESSION_CONCURRENCY', False):
|
||||
return True
|
||||
|
||||
# Get active sessions
|
||||
active_sessions = self.session_manager.get_active_sessions(user)
|
||||
|
||||
# Check if current session is in active sessions
|
||||
current_session_id = token.get('session_id')
|
||||
if not current_session_id:
|
||||
return True
|
||||
|
||||
# Remove current session from count
|
||||
active_sessions = [s for s in active_sessions if s['session_id'] != current_session_id]
|
||||
|
||||
return len(active_sessions) < self.session_manager.concurrent_sessions
|
||||
|
||||
def _get_client_ip(self, request) -> str:
|
||||
"""Get client IP address."""
|
||||
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
|
||||
if x_forwarded_for:
|
||||
return x_forwarded_for.split(',')[0].strip()
|
||||
return request.META.get('REMOTE_ADDR', 'unknown')
|
||||
|
||||
def _get_device_fingerprint(self, request) -> str:
|
||||
"""Get device fingerprint."""
|
||||
user_agent = request.META.get('HTTP_USER_AGENT', '')
|
||||
accept_language = request.META.get('HTTP_ACCEPT_LANGUAGE', '')
|
||||
return salted_hmac('device_fingerprint', f"{user_agent}{accept_language}").hexdigest()
|
||||
|
||||
# Signal handlers
|
||||
@receiver(user_logged_in)
|
||||
def user_logged_in_handler(sender, request, user, **kwargs):
|
||||
"""Handle successful user login."""
|
||||
logger.info(f"User {user.username} logged in successfully from {request.META.get('REMOTE_ADDR')}")
|
||||
|
||||
# Reset failed login attempts
|
||||
password_manager = SecurePasswordManager()
|
||||
password_manager.reset_failed_logins(user)
|
||||
|
||||
# Create secure session
|
||||
session_manager = SecureSessionManager()
|
||||
session_id = session_manager.create_secure_session(user, request)
|
||||
|
||||
# Store session ID in request for use in JWT
|
||||
request.session_id = session_id
|
||||
|
||||
@receiver(user_logged_out)
|
||||
def user_logged_out_handler(sender, request, user, **kwargs):
|
||||
"""Handle user logout."""
|
||||
logger.info(f"User {user.username} logged out")
|
||||
|
||||
# Revoke session
|
||||
session_manager = SecureSessionManager()
|
||||
if hasattr(request, 'session_id'):
|
||||
session_manager.revoke_session(request.session_id)
|
||||
|
||||
@receiver(user_login_failed)
|
||||
def user_login_failed_handler(sender, credentials, request, **kwargs):
|
||||
"""Handle failed login attempt."""
|
||||
username = credentials.get('username', 'unknown')
|
||||
logger.warning(f"Failed login attempt for username: {username} from {request.META.get('REMOTE_ADDR')}")
|
||||
|
||||
# Get user if exists
|
||||
try:
|
||||
user = User.objects.get(username=username)
|
||||
password_manager = SecurePasswordManager()
|
||||
password_manager.record_failed_login(user)
|
||||
except User.DoesNotExist:
|
||||
pass
|
||||
|
||||
# Global instances
|
||||
malaysian_business_auth = MalaysianBusinessAuth()
|
||||
password_manager = SecurePasswordManager()
|
||||
session_manager = SecureSessionManager()
|
||||
secure_jwt_auth = SecureJWTAuthentication()
|
||||
504
backend/security/headers.py
Normal file
504
backend/security/headers.py
Normal file
@@ -0,0 +1,504 @@
|
||||
"""
|
||||
Security headers and Content Security Policy (CSP) management for enhanced security.
|
||||
"""
|
||||
|
||||
from django.conf import settings
|
||||
from django.http import HttpResponse
|
||||
from django.middleware.security import SecurityMiddleware as DjangoSecurityMiddleware
|
||||
from django.utils.deprecation import MiddlewareMixin
|
||||
from django.core.exceptions import MiddlewareNotUsed
|
||||
import re
|
||||
import json
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(MiddlewareMixin):
|
||||
"""
|
||||
Enhanced security middleware with comprehensive security headers and CSP policies.
|
||||
"""
|
||||
|
||||
def __init__(self, get_response):
|
||||
super().__init__(get_response)
|
||||
|
||||
# Initialize CSP configuration
|
||||
self.csp_config = getattr(settings, 'CSP_CONFIG', self._get_default_csp_config())
|
||||
|
||||
# Initialize security headers
|
||||
self.security_headers = getattr(settings, 'SECURITY_HEADERS', self._get_default_security_headers())
|
||||
|
||||
# Initialize allowed domains
|
||||
self.allowed_domains = set(getattr(settings, 'ALLOWED_DOMAINS', []))
|
||||
|
||||
# Initialize nonce generator
|
||||
self.nonce_generator = CSPNonceGenerator()
|
||||
|
||||
def process_response(self, request, response):
|
||||
"""
|
||||
Add security headers to the response.
|
||||
"""
|
||||
# Skip for static files and media
|
||||
if self._should_skip_headers(request):
|
||||
return response
|
||||
|
||||
# Add security headers
|
||||
for header, value in self.security_headers.items():
|
||||
response[header] = value
|
||||
|
||||
# Add CSP header
|
||||
csp_value = self._generate_csp_header(request)
|
||||
if csp_value:
|
||||
response['Content-Security-Policy'] = csp_value
|
||||
|
||||
# Add Report-Only CSP in development
|
||||
if settings.DEBUG and getattr(settings, 'CSP_REPORT_ONLY', False):
|
||||
response['Content-Security-Policy-Report-Only'] = csp_value
|
||||
|
||||
# Add feature policy
|
||||
feature_policy = self._generate_feature_policy()
|
||||
if feature_policy:
|
||||
response['Feature-Policy'] = feature_policy
|
||||
|
||||
# Add permissions policy
|
||||
permissions_policy = self._generate_permissions_policy()
|
||||
if permissions_policy:
|
||||
response['Permissions-Policy'] = permissions_policy
|
||||
|
||||
# Add HSTS header in production
|
||||
if not settings.DEBUG and getattr(settings, 'SECURE_HSTS_SECONDS', 0):
|
||||
response['Strict-Transport-Security'] = self._generate_hsts_header()
|
||||
|
||||
return response
|
||||
|
||||
def _should_skip_headers(self, request) -> bool:
|
||||
"""
|
||||
Determine if security headers should be skipped for this request.
|
||||
"""
|
||||
# Skip for static files
|
||||
if request.path.startswith(settings.STATIC_URL):
|
||||
return True
|
||||
|
||||
# Skip for media files
|
||||
if request.path.startswith(settings.MEDIA_URL):
|
||||
return True
|
||||
|
||||
# Skip for health checks
|
||||
if request.path.startswith('/health/'):
|
||||
return True
|
||||
|
||||
# Skip for metrics
|
||||
if request.path.startswith('/metrics/'):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_default_csp_config(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Get default CSP configuration.
|
||||
"""
|
||||
return {
|
||||
'default-src': ["'self'"],
|
||||
'script-src': ["'self'", "'unsafe-inline'", "'unsafe-eval'"],
|
||||
'style-src': ["'self'", "'unsafe-inline'"],
|
||||
'img-src': ["'self'", "data:", "https:"],
|
||||
'font-src': ["'self'", "data:"],
|
||||
'connect-src': ["'self'"],
|
||||
'media-src': ["'self'"],
|
||||
'object-src': ["'none'"],
|
||||
'frame-src': ["'self'"],
|
||||
'frame-ancestors': ["'self'"],
|
||||
'form-action': ["'self'"],
|
||||
'base-uri': ["'self'"],
|
||||
'manifest-src': ["'self'"],
|
||||
'worker-src': ["'self'"],
|
||||
'child-src': ["'self'"],
|
||||
'prefetch-src': ["'self'"],
|
||||
'require-trusted-types-for': ["'script'"],
|
||||
'trusted-types': ["'default'"],
|
||||
'upgrade-insecure-requests': [],
|
||||
'block-all-mixed-content': [],
|
||||
'report-uri': ['/csp-report-endpoint/'],
|
||||
'report-to': ['csp-endpoint'],
|
||||
}
|
||||
|
||||
def _get_default_security_headers(self) -> Dict[str, str]:
|
||||
"""
|
||||
Get default security headers.
|
||||
"""
|
||||
return {
|
||||
'X-Content-Type-Options': 'nosniff',
|
||||
'X-Frame-Options': 'DENY',
|
||||
'X-XSS-Protection': '1; mode=block',
|
||||
'Referrer-Policy': 'strict-origin-when-cross-origin',
|
||||
'X-Permitted-Cross-Domain-Policies': 'none',
|
||||
'Clear-Site-Data': '"cache", "cookies", "storage"',
|
||||
'Cross-Origin-Opener-Policy': 'same-origin',
|
||||
'Cross-Origin-Embedder-Policy': 'require-corp',
|
||||
'Cross-Origin-Resource-Policy': 'same-origin',
|
||||
}
|
||||
|
||||
def _generate_csp_header(self, request) -> str:
|
||||
"""
|
||||
Generate CSP header value based on configuration.
|
||||
"""
|
||||
directives = []
|
||||
|
||||
for directive, sources in self.csp_config.items():
|
||||
if sources:
|
||||
# Add nonce for script and style directives
|
||||
if directive in ['script-src', 'style-src'] and "'unsafe-inline'" in sources:
|
||||
nonce = self.nonce_generator.get_nonce()
|
||||
sources.remove("'unsafe-inline'")
|
||||
sources.append(f"'nonce-{nonce}'")
|
||||
|
||||
# Join sources
|
||||
source_list = ' '.join(sources)
|
||||
directives.append(f"{directive} {source_list}")
|
||||
|
||||
return '; '.join(directives)
|
||||
|
||||
def _generate_feature_policy(self) -> str:
|
||||
"""
|
||||
Generate Feature Policy header.
|
||||
"""
|
||||
policies = [
|
||||
'camera none',
|
||||
'microphone none',
|
||||
'geolocation none',
|
||||
'payment none',
|
||||
'usb none',
|
||||
'magnetometer none',
|
||||
'gyroscope none',
|
||||
'accelerometer none',
|
||||
'fullscreen self',
|
||||
'document-domain none',
|
||||
'sync-xhr self',
|
||||
'usb none',
|
||||
]
|
||||
|
||||
return ', '.join(policies)
|
||||
|
||||
def _generate_permissions_policy(self) -> str:
|
||||
"""
|
||||
Generate Permissions Policy header.
|
||||
"""
|
||||
policies = [
|
||||
'camera=()',
|
||||
'microphone=()',
|
||||
'geolocation=()',
|
||||
'payment=()',
|
||||
'usb=()',
|
||||
'magnetometer=()',
|
||||
'gyroscope=()',
|
||||
'accelerometer=()',
|
||||
'fullscreen=(self)',
|
||||
'document-domain=()',
|
||||
'sync-xhr=(self)',
|
||||
'usb=()',
|
||||
]
|
||||
|
||||
return ', '.join(policies)
|
||||
|
||||
def _generate_hsts_header(self) -> str:
|
||||
"""
|
||||
Generate HSTS header.
|
||||
"""
|
||||
max_age = getattr(settings, 'SECURE_HSTS_SECONDS', 31536000)
|
||||
include_subdomains = getattr(settings, 'SECURE_HSTS_INCLUDE_SUBDOMAINS', True)
|
||||
preload = getattr(settings, 'SECURE_HSTS_PRELOAD', False)
|
||||
|
||||
header = f'max-age={max_age}'
|
||||
|
||||
if include_subdomains:
|
||||
header += '; includeSubDomains'
|
||||
|
||||
if preload:
|
||||
header += '; preload'
|
||||
|
||||
return header
|
||||
|
||||
|
||||
class CSPNonceGenerator:
|
||||
"""
|
||||
Generator for CSP nonces.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._nonces = set()
|
||||
self._max_nonces = 1000 # Prevent memory leaks
|
||||
|
||||
def get_nonce(self) -> str:
|
||||
"""
|
||||
Generate a new nonce.
|
||||
"""
|
||||
import secrets
|
||||
|
||||
# Clean up old nonces if we have too many
|
||||
if len(self._nonces) > self._max_nonces:
|
||||
self._nonces.clear()
|
||||
|
||||
# Generate new nonce
|
||||
nonce = secrets.token_urlsafe(16)
|
||||
self._nonces.add(nonce)
|
||||
|
||||
return nonce
|
||||
|
||||
def is_valid_nonce(self, nonce: str) -> bool:
|
||||
"""
|
||||
Check if a nonce is valid.
|
||||
"""
|
||||
return nonce in self._nonces
|
||||
|
||||
def clear_nonces(self):
|
||||
"""
|
||||
Clear all nonces.
|
||||
"""
|
||||
self._nonces.clear()
|
||||
|
||||
|
||||
class CSPReportHandler:
|
||||
"""
|
||||
Handler for CSP violation reports.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger('security.csp')
|
||||
|
||||
def handle_report(self, report_data: Dict):
|
||||
"""
|
||||
Handle CSP violation report.
|
||||
"""
|
||||
try:
|
||||
# Log the violation
|
||||
self.logger.warning(
|
||||
f"CSP Violation: {report_data.get('document-uri')} - "
|
||||
f"{report_data.get('violated-directive')} - "
|
||||
f"{report_data.get('blocked-uri')}"
|
||||
)
|
||||
|
||||
# Send to monitoring system
|
||||
self._send_to_monitoring(report_data)
|
||||
|
||||
# Store for analysis
|
||||
self._store_violation(report_data)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error handling CSP report: {e}")
|
||||
|
||||
def _send_to_monitoring(self, report_data: Dict):
|
||||
"""
|
||||
Send violation report to monitoring system.
|
||||
"""
|
||||
try:
|
||||
from monitoring.alerts import alert_manager
|
||||
from monitoring.alerts import Alert, AlertSeverity, AlertCategory
|
||||
|
||||
alert = Alert(
|
||||
title="CSP Violation",
|
||||
description=f"CSP violation detected: {report_data.get('violated-directive')}",
|
||||
severity=AlertSeverity.WARNING,
|
||||
category=AlertCategory.SECURITY,
|
||||
metadata={
|
||||
'document_uri': report_data.get('document-uri'),
|
||||
'violated_directive': report_data.get('violated-directive'),
|
||||
'blocked_uri': report_data.get('blocked-uri'),
|
||||
'line_number': report_data.get('line-number'),
|
||||
'column_number': report_data.get('column-number'),
|
||||
}
|
||||
)
|
||||
|
||||
alert_manager.trigger_alert(alert)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error sending CSP report to monitoring: {e}")
|
||||
|
||||
def _store_violation(self, report_data: Dict):
|
||||
"""
|
||||
Store violation for analysis.
|
||||
"""
|
||||
try:
|
||||
from django.core.cache import cache
|
||||
|
||||
# Store recent violations
|
||||
cache_key = f"csp_violations_{report_data.get('document-uri', 'unknown')}"
|
||||
violations = cache.get(cache_key, [])
|
||||
|
||||
violations.append({
|
||||
'timestamp': report_data.get('timestamp'),
|
||||
'violated_directive': report_data.get('violated-directive'),
|
||||
'blocked_uri': report_data.get('blocked-uri'),
|
||||
'line_number': report_data.get('line-number'),
|
||||
'column_number': report_data.get('column-number'),
|
||||
})
|
||||
|
||||
# Keep only last 100 violations
|
||||
if len(violations) > 100:
|
||||
violations = violations[-100:]
|
||||
|
||||
cache.set(cache_key, violations, timeout=86400) # 24 hours
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error storing CSP violation: {e}")
|
||||
|
||||
|
||||
class SecurityHeaderValidator:
|
||||
"""
|
||||
Validator for security headers.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger('security.headers')
|
||||
|
||||
def validate_headers(self, response: HttpResponse) -> Dict[str, bool]:
|
||||
"""
|
||||
Validate security headers in response.
|
||||
"""
|
||||
results = {}
|
||||
|
||||
# Validate CSP header
|
||||
results['csp'] = self._validate_csp_header(response)
|
||||
|
||||
# Validate HSTS header
|
||||
results['hsts'] = self._validate_hsts_header(response)
|
||||
|
||||
# Validate other security headers
|
||||
results['x_content_type_options'] = self._validate_x_content_type_options(response)
|
||||
results['x_frame_options'] = self._validate_x_frame_options(response)
|
||||
results['x_xss_protection'] = self._validate_x_xss_protection(response)
|
||||
results['referrer_policy'] = self._validate_referrer_policy(response)
|
||||
|
||||
return results
|
||||
|
||||
def _validate_csp_header(self, response: HttpResponse) -> bool:
|
||||
"""
|
||||
Validate CSP header.
|
||||
"""
|
||||
csp_header = response.get('Content-Security-Policy')
|
||||
if not csp_header:
|
||||
self.logger.warning("Missing CSP header")
|
||||
return False
|
||||
|
||||
# Check for required directives
|
||||
required_directives = ['default-src', 'script-src', 'style-src']
|
||||
for directive in required_directives:
|
||||
if f"{directive} " not in csp_header:
|
||||
self.logger.warning(f"Missing required CSP directive: {directive}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _validate_hsts_header(self, response: HttpResponse) -> bool:
|
||||
"""
|
||||
Validate HSTS header.
|
||||
"""
|
||||
hsts_header = response.get('Strict-Transport-Security')
|
||||
if not hsts_header:
|
||||
self.logger.warning("Missing HSTS header")
|
||||
return False
|
||||
|
||||
# Check for max-age
|
||||
if 'max-age=' not in hsts_header:
|
||||
self.logger.warning("HSTS header missing max-age")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _validate_x_content_type_options(self, response: HttpResponse) -> bool:
|
||||
"""
|
||||
Validate X-Content-Type-Options header.
|
||||
"""
|
||||
header = response.get('X-Content-Type-Options')
|
||||
if header != 'nosniff':
|
||||
self.logger.warning("Invalid X-Content-Type-Options header")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _validate_x_frame_options(self, response: HttpResponse) -> bool:
|
||||
"""
|
||||
Validate X-Frame-Options header.
|
||||
"""
|
||||
header = response.get('X-Frame-Options')
|
||||
if header not in ['DENY', 'SAMEORIGIN']:
|
||||
self.logger.warning("Invalid X-Frame-Options header")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _validate_x_xss_protection(self, response: HttpResponse) -> bool:
|
||||
"""
|
||||
Validate X-XSS-Protection header.
|
||||
"""
|
||||
header = response.get('X-XSS-Protection')
|
||||
if header != '1; mode=block':
|
||||
self.logger.warning("Invalid X-XSS-Protection header")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _validate_referrer_policy(self, response: HttpResponse) -> bool:
|
||||
"""
|
||||
Validate Referrer-Policy header.
|
||||
"""
|
||||
header = response.get('Referrer-Policy')
|
||||
valid_policies = [
|
||||
'no-referrer',
|
||||
'no-referrer-when-downgrade',
|
||||
'origin',
|
||||
'origin-when-cross-origin',
|
||||
'same-origin',
|
||||
'strict-origin',
|
||||
'strict-origin-when-cross-origin',
|
||||
'unsafe-url'
|
||||
]
|
||||
|
||||
if header not in valid_policies:
|
||||
self.logger.warning("Invalid Referrer-Policy header")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class SecurityHeaderMiddleware(SecurityHeadersMiddleware):
|
||||
"""
|
||||
Enhanced security middleware with Malaysian-specific security considerations.
|
||||
"""
|
||||
|
||||
def __init__(self, get_response):
|
||||
super().__init__(get_response)
|
||||
|
||||
# Initialize Malaysian-specific security headers
|
||||
self.malaysian_headers = getattr(settings, 'MALAYSIAN_SECURITY_HEADERS', {
|
||||
'X-Malaysian-Data-Protection': 'PDPA-Compliant',
|
||||
'X-Malaysian-Privacy-Policy': '/privacy-policy/',
|
||||
'X-Malaysian-Contact': '/contact/',
|
||||
})
|
||||
|
||||
# Initialize validator
|
||||
self.validator = SecurityHeaderValidator()
|
||||
|
||||
# Initialize report handler
|
||||
self.report_handler = CSPReportHandler()
|
||||
|
||||
def process_response(self, request, response):
|
||||
"""
|
||||
Add security headers with Malaysian-specific considerations.
|
||||
"""
|
||||
# Call parent method
|
||||
response = super().process_response(request, response)
|
||||
|
||||
# Add Malaysian-specific headers
|
||||
for header, value in self.malaysian_headers.items():
|
||||
response[header] = value
|
||||
|
||||
# Validate headers
|
||||
if settings.DEBUG:
|
||||
validation_results = self.validator.validate_headers(response)
|
||||
for header, is_valid in validation_results.items():
|
||||
if not is_valid:
|
||||
self.validator.logger.warning(f"Invalid {header} header")
|
||||
|
||||
return response
|
||||
806
backend/security/middleware.py
Normal file
806
backend/security/middleware.py
Normal file
@@ -0,0 +1,806 @@
|
||||
"""
|
||||
Security middleware for comprehensive protection.
|
||||
Implements Malaysian data protection and security best practices.
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from django.conf import settings
|
||||
from django.http import HttpRequest, HttpResponse, JsonResponse
|
||||
from django.core.cache import cache
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.utils import timezone
|
||||
from django.utils.crypto import get_random_string
|
||||
from django.middleware.security import SecurityMiddleware as DjangoSecurityMiddleware
|
||||
from django.middleware.clickjacking import XFrameOptionsMiddleware
|
||||
from django.middleware.csrf import CsrfViewMiddleware
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from django.utils.deprecation import MiddlewareMixin
|
||||
from prometheus_client import Counter, Histogram, Gauge
|
||||
import redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
User = get_user_model()
|
||||
|
||||
# Security metrics
|
||||
SECURITY_EVENTS = Counter(
|
||||
'security_events_total',
|
||||
'Security events',
|
||||
['event_type', 'severity', 'ip_address', 'user_agent', 'tenant']
|
||||
)
|
||||
|
||||
RATE_LIMIT_EVENTS = Counter(
|
||||
'rate_limit_events_total',
|
||||
'Rate limit events',
|
||||
['type', 'ip_address', 'endpoint', 'tenant']
|
||||
)
|
||||
|
||||
MALAYSIAN_DATA_ACCESS = Counter(
|
||||
'malaysian_data_access_total',
|
||||
'Malaysian data access events',
|
||||
['data_type', 'operation', 'user_role', 'tenant']
|
||||
)
|
||||
|
||||
THREAT_DETECTION = Counter(
|
||||
'threat_detection_total',
|
||||
'Threat detection events',
|
||||
['threat_type', 'confidence', 'ip_address', 'tenant']
|
||||
)
|
||||
|
||||
class SecurityHeadersMiddleware(MiddlewareMixin):
|
||||
"""Enhanced security headers middleware."""
|
||||
|
||||
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
|
||||
"""Add comprehensive security headers."""
|
||||
# Security headers
|
||||
response['X-Content-Type-Options'] = 'nosniff'
|
||||
response['X-Frame-Options'] = 'DENY'
|
||||
response['X-XSS-Protection'] = '1; mode=block'
|
||||
response['Referrer-Policy'] = 'strict-origin-when-cross-origin'
|
||||
response['Permissions-Policy'] = self._get_permissions_policy()
|
||||
response['Content-Security-Policy'] = self._get_csp(request)
|
||||
response['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains; preload'
|
||||
|
||||
# Malaysian data protection headers
|
||||
response['X-Malaysian-Data-Protection'] = 'PDPA-Compliant'
|
||||
response['X-Data-Residency'] = 'Malaysia'
|
||||
|
||||
# Remove sensitive headers
|
||||
sensitive_headers = ['Server', 'X-Powered-By', 'X-AspNet-Version']
|
||||
for header in sensitive_headers:
|
||||
if header in response:
|
||||
del response[header]
|
||||
|
||||
return response
|
||||
|
||||
def _get_permissions_policy(self) -> str:
|
||||
"""Get permissions policy."""
|
||||
policies = [
|
||||
'accelerometer=()',
|
||||
'ambient-light-sensor=()',
|
||||
'battery=()',
|
||||
'bluetooth=()',
|
||||
'camera=()',
|
||||
'cross-origin-isolated=()',
|
||||
'display-capture=()',
|
||||
'document-domain=()',
|
||||
'encrypted-media=()',
|
||||
'execution-while-not-rendered=()',
|
||||
'execution-while-out-of-viewport=()',
|
||||
'focus-without-user-activation=()',
|
||||
'fullscreen=()',
|
||||
'geolocation=()',
|
||||
'gyroscope=()',
|
||||
'hid=()',
|
||||
'identity-credentials-get=()',
|
||||
'idle-detection=()',
|
||||
'local-fonts=()',
|
||||
'magnetometer=()',
|
||||
'microphone=()',
|
||||
'midi=()',
|
||||
'otp-credentials=()',
|
||||
'payment=()',
|
||||
'picture-in-picture=()',
|
||||
'publickey-credentials-get=()',
|
||||
'screen-wake-lock=()',
|
||||
'serial=()',
|
||||
'storage-access=()',
|
||||
'usb=()',
|
||||
'web-share=()',
|
||||
'window-management=()',
|
||||
'xr-spatial-tracking=()'
|
||||
]
|
||||
return ', '.join(policies)
|
||||
|
||||
def _get_csp(self, request: HttpRequest) -> str:
|
||||
"""Get Content Security Policy."""
|
||||
# Base CSP
|
||||
csp = [
|
||||
"default-src 'self'",
|
||||
"script-src 'self' 'unsafe-inline' 'unsafe-eval' https://cdn.jsdelivr.net https://www.google.com https://www.gstatic.com",
|
||||
"style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net https://fonts.googleapis.com",
|
||||
"img-src 'self' data: https: https://*.malaysian-sme-platform.com",
|
||||
"font-src 'self' https://fonts.gstatic.com https://fonts.googleapis.com",
|
||||
"connect-src 'self' https://api.malaysian-sme-platform.com wss://api.malaysian-sme-platform.com",
|
||||
"frame-ancestors 'none'",
|
||||
"form-action 'self'",
|
||||
"base-uri 'self'",
|
||||
"require-trusted-types-for 'script'",
|
||||
"report-uri /api/security/csp-report/",
|
||||
]
|
||||
|
||||
# Add development-specific policies
|
||||
if settings.DEBUG:
|
||||
csp[1] = csp[1].replace("'unsafe-inline'", "'unsafe-inline' 'unsafe-eval'")
|
||||
csp.append("upgrade-insecure-requests")
|
||||
|
||||
return '; '.join(csp)
|
||||
|
||||
class RateLimitingMiddleware(MiddlewareMixin):
|
||||
"""Advanced rate limiting middleware with Malaysian considerations."""
|
||||
|
||||
def __init__(self, get_response):
|
||||
self.get_response = get_response
|
||||
self.redis_client = self._get_redis_client()
|
||||
self.limits = self._get_rate_limits()
|
||||
|
||||
def process_request(self, request: HttpRequest) -> Optional[HttpResponse]:
|
||||
"""Process request for rate limiting."""
|
||||
if self._should_skip_rate_limiting(request):
|
||||
return None
|
||||
|
||||
ip_address = self._get_client_ip(request)
|
||||
user_id = self._get_user_id(request)
|
||||
endpoint = self._get_endpoint(request)
|
||||
tenant = self._get_tenant_info(request)
|
||||
|
||||
# Check all applicable limits
|
||||
for limit_type, limit_config in self.limits.items():
|
||||
if self._should_apply_limit(request, limit_type):
|
||||
limited = self._check_rate_limit(
|
||||
ip_address, user_id, endpoint, tenant, limit_type, limit_config
|
||||
)
|
||||
|
||||
if limited:
|
||||
return self._create_rate_limit_response(limit_type, limit_config)
|
||||
|
||||
return None
|
||||
|
||||
def _get_redis_client(self):
|
||||
"""Get Redis client for rate limiting."""
|
||||
try:
|
||||
return redis.from_url(settings.REDIS_URL)
|
||||
except Exception:
|
||||
logger.warning("Redis not available for rate limiting")
|
||||
return None
|
||||
|
||||
def _get_rate_limits(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get rate limit configurations."""
|
||||
return {
|
||||
'api': {
|
||||
'requests': 1000,
|
||||
'window': 3600, # 1 hour
|
||||
'scope': 'ip',
|
||||
'message': 'API rate limit exceeded'
|
||||
},
|
||||
'login': {
|
||||
'requests': 5,
|
||||
'window': 300, # 5 minutes
|
||||
'scope': 'ip',
|
||||
'message': 'Too many login attempts'
|
||||
},
|
||||
'malaysian_data': {
|
||||
'requests': 100,
|
||||
'window': 3600, # 1 hour
|
||||
'scope': 'user',
|
||||
'message': 'Malaysian data access rate limit exceeded'
|
||||
},
|
||||
'file_upload': {
|
||||
'requests': 50,
|
||||
'window': 3600, # 1 hour
|
||||
'scope': 'user',
|
||||
'message': 'File upload rate limit exceeded'
|
||||
},
|
||||
'sensitive_operations': {
|
||||
'requests': 10,
|
||||
'window': 3600, # 1 hour
|
||||
'scope': 'user',
|
||||
'message': 'Sensitive operations rate limit exceeded'
|
||||
},
|
||||
}
|
||||
|
||||
def _should_skip_rate_limiting(self, request: HttpRequest) -> bool:
|
||||
"""Check if request should skip rate limiting."""
|
||||
# Skip for health checks and static files
|
||||
skip_paths = ['/health/', '/metrics/', '/static/']
|
||||
if any(request.path.startswith(path) for path in skip_paths):
|
||||
return True
|
||||
|
||||
# Skip for authenticated staff users
|
||||
if hasattr(request, 'user') and request.user.is_authenticated and request.user.is_staff:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_client_ip(self, request: HttpRequest) -> str:
|
||||
"""Get client IP address with proxy support."""
|
||||
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
|
||||
if x_forwarded_for:
|
||||
ip = x_forwarded_for.split(',')[0].strip()
|
||||
else:
|
||||
ip = request.META.get('REMOTE_ADDR', 'unknown')
|
||||
|
||||
# Handle IPv6 loopback
|
||||
if ip == '::1':
|
||||
ip = '127.0.0.1'
|
||||
|
||||
return ip
|
||||
|
||||
def _get_user_id(self, request: HttpRequest) -> Optional[str]:
|
||||
"""Get user ID for rate limiting."""
|
||||
if hasattr(request, 'user') and request.user.is_authenticated:
|
||||
return str(request.user.id)
|
||||
return None
|
||||
|
||||
def _get_endpoint(self, request: HttpRequest) -> str:
|
||||
"""Get endpoint for rate limiting."""
|
||||
return request.path
|
||||
|
||||
def _get_tenant_info(self, request: HttpRequest) -> Dict[str, Any]:
|
||||
"""Get tenant information."""
|
||||
if hasattr(request, 'tenant') and request.tenant:
|
||||
return {
|
||||
'id': request.tenant.id,
|
||||
'name': request.tenant.name,
|
||||
'schema': request.tenant.schema_name
|
||||
}
|
||||
return {'id': None, 'name': 'public', 'schema': 'public'}
|
||||
|
||||
def _should_apply_limit(self, request: HttpRequest, limit_type: str) -> bool:
|
||||
"""Check if limit should be applied to request."""
|
||||
if limit_type == 'api' and request.path.startswith('/api/'):
|
||||
return True
|
||||
elif limit_type == 'login' and '/login' in request.path:
|
||||
return True
|
||||
elif limit_type == 'malaysian_data' and self._is_malaysian_data_endpoint(request):
|
||||
return True
|
||||
elif limit_type == 'file_upload' and request.method == 'POST' and 'upload' in request.path:
|
||||
return True
|
||||
elif limit_type == 'sensitive_operations' and self._is_sensitive_operation(request):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_malaysian_data_endpoint(self, request: HttpRequest) -> bool:
|
||||
"""Check if endpoint accesses Malaysian data."""
|
||||
malaysian_endpoints = [
|
||||
'/api/malaysian/',
|
||||
'/api/ic-validation/',
|
||||
'/api/sst/',
|
||||
'/api/postcode/',
|
||||
'/api/business-registration/',
|
||||
]
|
||||
return any(request.path.startswith(endpoint) for endpoint in malaysian_endpoints)
|
||||
|
||||
def _is_sensitive_operation(self, request: HttpRequest) -> bool:
|
||||
"""Check if operation is sensitive."""
|
||||
sensitive_operations = [
|
||||
'/api/users/',
|
||||
'/api/tenants/',
|
||||
'/api/admin/',
|
||||
'/api/payments/',
|
||||
'/api/export/',
|
||||
]
|
||||
return any(request.path.startswith(op) for op in sensitive_operations)
|
||||
|
||||
def _check_rate_limit(
|
||||
self,
|
||||
ip_address: str,
|
||||
user_id: Optional[str],
|
||||
endpoint: str,
|
||||
tenant: Dict[str, Any],
|
||||
limit_type: str,
|
||||
limit_config: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Check if rate limit is exceeded."""
|
||||
if not self.redis_client:
|
||||
return False
|
||||
|
||||
# Generate key based on scope
|
||||
if limit_config['scope'] == 'user' and user_id:
|
||||
key = f"rate_limit:{limit_type}:{user_id}:{tenant['id']}"
|
||||
else:
|
||||
key = f"rate_limit:{limit_type}:{ip_address}:{tenant['id']}"
|
||||
|
||||
# Check current count
|
||||
current_count = self.redis_client.get(key)
|
||||
if current_count is None:
|
||||
current_count = 0
|
||||
else:
|
||||
current_count = int(current_count)
|
||||
|
||||
# Check if limit exceeded
|
||||
if current_count >= limit_config['requests']:
|
||||
RATE_LIMIT_EVENTS.labels(
|
||||
type=limit_type,
|
||||
ip_address=ip_address,
|
||||
endpoint=endpoint,
|
||||
tenant=tenant.get('name', 'unknown')
|
||||
).inc()
|
||||
|
||||
SECURITY_EVENTS.labels(
|
||||
event_type='rate_limit_exceeded',
|
||||
severity='warning',
|
||||
ip_address=ip_address,
|
||||
user_agent=request.META.get('HTTP_USER_AGENT', 'unknown'),
|
||||
tenant=tenant.get('name', 'unknown')
|
||||
).inc()
|
||||
|
||||
return True
|
||||
|
||||
# Increment counter
|
||||
self.redis_client.incr(key)
|
||||
self.redis_client.expire(key, limit_config['window'])
|
||||
|
||||
return False
|
||||
|
||||
def _create_rate_limit_response(self, limit_type: str, limit_config: Dict[str, Any]) -> JsonResponse:
|
||||
"""Create rate limit response."""
|
||||
response_data = {
|
||||
'error': limit_config['message'],
|
||||
'type': 'rate_limit_exceeded',
|
||||
'retry_after': limit_config['window'],
|
||||
'limit_type': limit_type
|
||||
}
|
||||
|
||||
return JsonResponse(response_data, status=429)
|
||||
|
||||
class InputValidationMiddleware(MiddlewareMixin):
|
||||
"""Input validation and sanitization middleware."""
|
||||
|
||||
def __init__(self, get_response):
|
||||
self.get_response = get_response
|
||||
self.suspicious_patterns = self._get_suspicious_patterns()
|
||||
self.max_input_size = getattr(settings, 'MAX_INPUT_SIZE', 1024 * 1024) # 1MB
|
||||
|
||||
def process_request(self, request: HttpRequest) -> Optional[HttpResponse]:
|
||||
"""Validate and sanitize input."""
|
||||
# Check input size
|
||||
if not self._check_input_size(request):
|
||||
return JsonResponse(
|
||||
{'error': 'Request size too large'},
|
||||
status=413
|
||||
)
|
||||
|
||||
# Validate input for POST/PUT/PATCH
|
||||
if request.method in ['POST', 'PUT', 'PATCH']:
|
||||
if not self._validate_input(request):
|
||||
return JsonResponse(
|
||||
{'error': 'Invalid input detected'},
|
||||
status=400
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _get_suspicious_patterns(self) -> List[re.Pattern]:
|
||||
"""Get suspicious input patterns."""
|
||||
patterns = [
|
||||
# SQL injection
|
||||
re.compile(r'(?i)\b(SELECT|INSERT|UPDATE|DELETE|DROP|UNION|EXEC|ALTER|CREATE|TRUNCATE)\b.*\b(FROM|INTO|TABLE|DATABASE)\b'),
|
||||
re.compile(r'(?i)\b(OR\s+1\s*=\s*1|OR\s+TRUE|AND\s+1\s*=\s*1)\b'),
|
||||
re.compile(r'(?i)\b(WAITFOR\s+DELAY|SLEEP\(|PG_SLEEP\(|BENCHMARK\()\b'),
|
||||
|
||||
# XSS
|
||||
re.compile(r'(?i)<(script|iframe|object|embed|form|input)\b.*?>'),
|
||||
re.compile(r'(?i)javascript:'),
|
||||
re.compile(r'(?i)on\w+\s*='),
|
||||
re.compile(r'(?i)(eval|Function|setTimeout|setInterval)\s*\('),
|
||||
|
||||
# Path traversal
|
||||
re.compile(r'\.\./'),
|
||||
re.compile(r'(?i)\b(/etc|/var|/usr|/home)/'),
|
||||
re.compile(r'(?i)\.(htaccess|htpasswd|env)\b'),
|
||||
|
||||
# Command injection
|
||||
re.compile(r'(?i);\s*(rm|ls|cat|pwd|whoami|id|ps|netstat|curl|wget)\s'),
|
||||
re.compile(r'(?i)\|(\s*)(rm|ls|cat|pwd|whoami|id|ps|netstat|curl|wget)\s'),
|
||||
re.compile(r'(?i)&(\s*)(rm|ls|cat|pwd|whoami|id|ps|netstat|curl|wget)\s'),
|
||||
|
||||
# NoSQL injection
|
||||
re.compile(r'(?i)\$where\b'),
|
||||
re.compile(r'(?i)\b(db\.eval|mapReduce|group)\b'),
|
||||
|
||||
# LDAP injection
|
||||
re.compile(r'(?i)\*\)\)(\|\()'),
|
||||
re.compile(r'(?i)\)\)(\|\()'),
|
||||
|
||||
# XML injection
|
||||
re.compile(r'<!ENTITY\s+'),
|
||||
re.compile(r'<\?xml\s+'),
|
||||
]
|
||||
|
||||
return patterns
|
||||
|
||||
def _check_input_size(self, request: HttpRequest) -> bool:
|
||||
"""Check if request size is within limits."""
|
||||
# Check GET parameters
|
||||
for key, value in request.GET.items():
|
||||
if len(str(value)) > self.max_input_size:
|
||||
return False
|
||||
|
||||
# Check POST data
|
||||
if hasattr(request, 'POST'):
|
||||
for key, value in request.POST.items():
|
||||
if len(str(value)) > self.max_input_size:
|
||||
return False
|
||||
|
||||
# Check JSON body
|
||||
if hasattr(request, 'body') and request.body:
|
||||
if len(request.body) > self.max_input_size:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _validate_input(self, request: HttpRequest) -> bool:
|
||||
"""Validate request input for malicious patterns."""
|
||||
ip_address = self._get_client_ip(request)
|
||||
user_agent = request.META.get('HTTP_USER_AGENT', 'unknown')
|
||||
tenant = self._get_tenant_info(request)
|
||||
|
||||
# Check GET parameters
|
||||
for key, value in request.GET.items():
|
||||
if self._contains_suspicious_pattern(str(value)):
|
||||
SECURITY_EVENTS.labels(
|
||||
event_type='suspicious_input',
|
||||
severity='warning',
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
tenant=tenant.get('name', 'unknown')
|
||||
).inc()
|
||||
return False
|
||||
|
||||
# Check POST data
|
||||
if hasattr(request, 'POST'):
|
||||
for key, value in request.POST.items():
|
||||
if self._contains_suspicious_pattern(str(value)):
|
||||
SECURITY_EVENTS.labels(
|
||||
event_type='suspicious_input',
|
||||
severity='warning',
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
tenant=tenant.get('name', 'unknown')
|
||||
).inc()
|
||||
return False
|
||||
|
||||
# Check JSON body
|
||||
if hasattr(request, 'body') and request.body:
|
||||
try:
|
||||
body_str = request.body.decode('utf-8')
|
||||
if self._contains_suspicious_pattern(body_str):
|
||||
SECURITY_EVENTS.labels(
|
||||
event_type='suspicious_input',
|
||||
severity='warning',
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
tenant=tenant.get('name', 'unknown')
|
||||
).inc()
|
||||
return False
|
||||
except UnicodeDecodeError:
|
||||
SECURITY_EVENTS.labels(
|
||||
event_type='invalid_encoding',
|
||||
severity='warning',
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
tenant=tenant.get('name', 'unknown')
|
||||
).inc()
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _contains_suspicious_pattern(self, input_str: str) -> bool:
|
||||
"""Check if input contains suspicious patterns."""
|
||||
for pattern in self.suspicious_patterns:
|
||||
if pattern.search(input_str):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_client_ip(self, request: HttpRequest) -> str:
|
||||
"""Get client IP address."""
|
||||
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
|
||||
if x_forwarded_for:
|
||||
return x_forwarded_for.split(',')[0].strip()
|
||||
return request.META.get('REMOTE_ADDR', 'unknown')
|
||||
|
||||
def _get_tenant_info(self, request: HttpRequest) -> Dict[str, Any]:
|
||||
"""Get tenant information."""
|
||||
if hasattr(request, 'tenant') and request.tenant:
|
||||
return {
|
||||
'id': request.tenant.id,
|
||||
'name': request.tenant.name,
|
||||
'schema': request.tenant.schema_name
|
||||
}
|
||||
return {'id': None, 'name': 'public', 'schema': 'public'}
|
||||
|
||||
class DataProtectionMiddleware(MiddlewareMixin):
|
||||
"""Malaysian data protection compliance middleware."""
|
||||
|
||||
def __init__(self, get_response):
|
||||
self.get_response = get_response
|
||||
self.sensitive_data_fields = self._get_sensitive_data_fields()
|
||||
self.required_consent_version = getattr(settings, 'REQUIRED_CONSENT_VERSION', '1.0')
|
||||
|
||||
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
|
||||
"""Process response for data protection."""
|
||||
# Add Malaysian data protection headers
|
||||
response['X-Malaysian-Data-Protection'] = 'PDPA-Compliant'
|
||||
response['X-Data-Residency'] = 'Malaysia'
|
||||
|
||||
# Log Malaysian data access
|
||||
if self._is_malaysian_data_access(request):
|
||||
self._log_malaysian_data_access(request)
|
||||
|
||||
# Sanitize response data
|
||||
if hasattr(response, 'data') and isinstance(response.data, dict):
|
||||
response.data = self._sanitize_response_data(response.data)
|
||||
|
||||
return response
|
||||
|
||||
def _get_sensitive_data_fields(self) -> List[str]:
|
||||
"""Get sensitive data fields that require protection."""
|
||||
return [
|
||||
'ic_number',
|
||||
'passport_number',
|
||||
'email',
|
||||
'phone_number',
|
||||
'address',
|
||||
'bank_account',
|
||||
'salary',
|
||||
'business_registration_number',
|
||||
'tax_id',
|
||||
]
|
||||
|
||||
def _is_malaysian_data_access(self, request: HttpRequest) -> bool:
|
||||
"""Check if request accesses Malaysian data."""
|
||||
malaysian_endpoints = [
|
||||
'/api/malaysian/',
|
||||
'/api/ic-validation/',
|
||||
'/api/sst/',
|
||||
'/api/postcode/',
|
||||
'/api/business-registration/',
|
||||
]
|
||||
|
||||
return any(request.path.startswith(endpoint) for endpoint in malaysian_endpoints)
|
||||
|
||||
def _log_malaysian_data_access(self, request: HttpRequest):
|
||||
"""Log Malaysian data access for compliance."""
|
||||
user_role = 'anonymous'
|
||||
if hasattr(request, 'user') and request.user.is_authenticated:
|
||||
user_role = request.user.role
|
||||
|
||||
tenant = self._get_tenant_info(request)
|
||||
|
||||
# Determine data type
|
||||
data_type = 'unknown'
|
||||
if '/ic-validation/' in request.path:
|
||||
data_type = 'ic_data'
|
||||
elif '/sst/' in request.path:
|
||||
data_type = 'tax_data'
|
||||
elif '/postcode/' in request.path:
|
||||
data_type = 'location_data'
|
||||
elif '/business-registration/' in request.path:
|
||||
data_type = 'business_data'
|
||||
|
||||
MALAYSIAN_DATA_ACCESS.labels(
|
||||
data_type=data_type,
|
||||
operation=request.method,
|
||||
user_role=user_role,
|
||||
tenant=tenant.get('name', 'unknown')
|
||||
).inc()
|
||||
|
||||
def _sanitize_response_data(self, data: Any) -> Any:
|
||||
"""Sanitize response data to remove sensitive information."""
|
||||
if isinstance(data, dict):
|
||||
sanitized = {}
|
||||
for key, value in data.items():
|
||||
if key.lower() in [field.lower() for field in self.sensitive_data_fields]:
|
||||
sanitized[key] = self._mask_sensitive_data(key, value)
|
||||
else:
|
||||
sanitized[key] = self._sanitize_response_data(value)
|
||||
return sanitized
|
||||
elif isinstance(data, list):
|
||||
return [self._sanitize_response_data(item) for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
def _mask_sensitive_data(self, field: str, value: Any) -> str:
|
||||
"""Mask sensitive data for logging/display."""
|
||||
if field.lower() in ['ic_number', 'passport_number']:
|
||||
return value[:2] + '*' * (len(value) - 4) + value[-2:]
|
||||
elif field.lower() in ['email']:
|
||||
return value[:3] + '*' * (len(value.split('@')[0]) - 3) + '@' + value.split('@')[1]
|
||||
elif field.lower() in ['phone_number']:
|
||||
return value[:3] + '*' * (len(value) - 6) + value[-3:]
|
||||
elif field.lower() in ['bank_account']:
|
||||
return '*' * (len(value) - 4) + value[-4:]
|
||||
else:
|
||||
return '*' * len(str(value))
|
||||
|
||||
def _get_tenant_info(self, request: HttpRequest) -> Dict[str, Any]:
|
||||
"""Get tenant information."""
|
||||
if hasattr(request, 'tenant') and request.tenant:
|
||||
return {
|
||||
'id': request.tenant.id,
|
||||
'name': request.tenant.name,
|
||||
'schema': request.tenant.schema_name
|
||||
}
|
||||
return {'id': None, 'name': 'public', 'schema': 'public'}
|
||||
|
||||
class SecurityLoggingMiddleware(MiddlewareMixin):
|
||||
"""Security event logging middleware."""
|
||||
|
||||
def __init__(self, get_response):
|
||||
self.get_response = get_response
|
||||
self.security_log_fields = [
|
||||
'ip_address',
|
||||
'user_agent',
|
||||
'timestamp',
|
||||
'endpoint',
|
||||
'method',
|
||||
'user_id',
|
||||
'tenant',
|
||||
'event_type',
|
||||
'severity',
|
||||
'details'
|
||||
]
|
||||
|
||||
def process_request(self, request: HttpRequest) -> Optional[HttpResponse]:
|
||||
"""Log security-relevant requests."""
|
||||
# Log authentication attempts
|
||||
if '/login' in request.path or '/auth/' in request.path:
|
||||
self._log_auth_attempt(request)
|
||||
|
||||
# Log admin access
|
||||
if '/admin/' in request.path:
|
||||
self._log_admin_access(request)
|
||||
|
||||
# Log Malaysian data access
|
||||
if self._is_malaysian_data_access(request):
|
||||
self._log_malaysian_access(request)
|
||||
|
||||
return None
|
||||
|
||||
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
|
||||
"""Log security-relevant responses."""
|
||||
# Log failed requests
|
||||
if response.status_code >= 400:
|
||||
self._log_failed_request(request, response)
|
||||
|
||||
# Log rate limiting
|
||||
if response.status_code == 429:
|
||||
self._log_rate_limit(request, response)
|
||||
|
||||
return response
|
||||
|
||||
def _log_auth_attempt(self, request: HttpRequest):
|
||||
"""Log authentication attempt."""
|
||||
event_type = 'login_attempt'
|
||||
severity = 'info'
|
||||
|
||||
ip_address = self._get_client_ip(request)
|
||||
user_agent = request.META.get('HTTP_USER_AGENT', 'unknown')
|
||||
tenant = self._get_tenant_info(request)
|
||||
|
||||
SECURITY_EVENTS.labels(
|
||||
event_type=event_type,
|
||||
severity=severity,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
tenant=tenant.get('name', 'unknown')
|
||||
).inc()
|
||||
|
||||
def _log_admin_access(self, request: HttpRequest):
|
||||
"""Log admin area access."""
|
||||
if not hasattr(request, 'user') or not request.user.is_authenticated:
|
||||
event_type = 'unauthorized_admin_access'
|
||||
severity = 'warning'
|
||||
elif not request.user.is_staff:
|
||||
event_type = 'unauthorized_admin_access'
|
||||
severity = 'warning'
|
||||
else:
|
||||
event_type = 'admin_access'
|
||||
severity = 'info'
|
||||
|
||||
ip_address = self._get_client_ip(request)
|
||||
user_agent = request.META.get('HTTP_USER_AGENT', 'unknown')
|
||||
tenant = self._get_tenant_info(request)
|
||||
|
||||
SECURITY_EVENTS.labels(
|
||||
event_type=event_type,
|
||||
severity=severity,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
tenant=tenant.get('name', 'unknown')
|
||||
).inc()
|
||||
|
||||
def _log_malaysian_access(self, request: HttpRequest):
|
||||
"""Log Malaysian data access."""
|
||||
event_type = 'malaysian_data_access'
|
||||
severity = 'info'
|
||||
|
||||
ip_address = self._get_client_ip(request)
|
||||
user_agent = request.META.get('HTTP_USER_AGENT', 'unknown')
|
||||
tenant = self._get_tenant_info(request)
|
||||
|
||||
SECURITY_EVENTS.labels(
|
||||
event_type=event_type,
|
||||
severity=severity,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
tenant=tenant.get('name', 'unknown')
|
||||
).inc()
|
||||
|
||||
def _log_failed_request(self, request: HttpRequest, response: HttpResponse):
|
||||
"""Log failed requests."""
|
||||
event_type = 'failed_request'
|
||||
severity = 'warning' if response.status_code < 500 else 'error'
|
||||
|
||||
ip_address = self._get_client_ip(request)
|
||||
user_agent = request.META.get('HTTP_USER_AGENT', 'unknown')
|
||||
tenant = self._get_tenant_info(request)
|
||||
|
||||
SECURITY_EVENTS.labels(
|
||||
event_type=event_type,
|
||||
severity=severity,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
tenant=tenant.get('name', 'unknown')
|
||||
).inc()
|
||||
|
||||
def _log_rate_limit(self, request: HttpRequest, response: HttpResponse):
|
||||
"""Log rate limiting events."""
|
||||
event_type = 'rate_limit'
|
||||
severity = 'warning'
|
||||
|
||||
ip_address = self._get_client_ip(request)
|
||||
user_agent = request.META.get('HTTP_USER_AGENT', 'unknown')
|
||||
tenant = self._get_tenant_info(request)
|
||||
|
||||
RATE_LIMIT_EVENTS.labels(
|
||||
type='api',
|
||||
ip_address=ip_address,
|
||||
endpoint=request.path,
|
||||
tenant=tenant.get('name', 'unknown')
|
||||
).inc()
|
||||
|
||||
def _get_client_ip(self, request: HttpRequest) -> str:
|
||||
"""Get client IP address."""
|
||||
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
|
||||
if x_forwarded_for:
|
||||
return x_forwarded_for.split(',')[0].strip()
|
||||
return request.META.get('REMOTE_ADDR', 'unknown')
|
||||
|
||||
def _get_tenant_info(self, request: HttpRequest) -> Dict[str, Any]:
|
||||
"""Get tenant information."""
|
||||
if hasattr(request, 'tenant') and request.tenant:
|
||||
return {
|
||||
'id': request.tenant.id,
|
||||
'name': request.tenant.name,
|
||||
'schema': request.tenant.schema_name
|
||||
}
|
||||
return {'id': None, 'name': 'public', 'schema': 'public'}
|
||||
|
||||
def _is_malaysian_data_access(self, request: HttpRequest) -> bool:
|
||||
"""Check if request accesses Malaysian data."""
|
||||
malaysian_endpoints = [
|
||||
'/api/malaysian/',
|
||||
'/api/ic-validation/',
|
||||
'/api/sst/',
|
||||
'/api/postcode/',
|
||||
'/api/business-registration/',
|
||||
]
|
||||
|
||||
return any(request.path.startswith(endpoint) for endpoint in malaysian_endpoints)
|
||||
1288
backend/security/pdpa_compliance.py
Normal file
1288
backend/security/pdpa_compliance.py
Normal file
File diff suppressed because it is too large
Load Diff
1919
backend/security/security_testing.py
Normal file
1919
backend/security/security_testing.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user