Files
multitenetsaas/backend/security/api_security.py
AHMET YILMAZ b3fff546e9
Some checks failed
System Monitoring / Health Checks (push) Has been cancelled
System Monitoring / Performance Monitoring (push) Has been cancelled
System Monitoring / Database Monitoring (push) Has been cancelled
System Monitoring / Cache Monitoring (push) Has been cancelled
System Monitoring / Log Monitoring (push) Has been cancelled
System Monitoring / Resource Monitoring (push) Has been cancelled
System Monitoring / Uptime Monitoring (push) Has been cancelled
System Monitoring / Backup Monitoring (push) Has been cancelled
System Monitoring / Security Monitoring (push) Has been cancelled
System Monitoring / Monitoring Dashboard (push) Has been cancelled
System Monitoring / Alerting (push) Has been cancelled
Security Scanning / Dependency Scanning (push) Has been cancelled
Security Scanning / Code Security Scanning (push) Has been cancelled
Security Scanning / Secrets Scanning (push) Has been cancelled
Security Scanning / Container Security Scanning (push) Has been cancelled
Security Scanning / Compliance Checking (push) Has been cancelled
Security Scanning / Security Dashboard (push) Has been cancelled
Security Scanning / Security Remediation (push) Has been cancelled
project initialization
2025-10-05 02:37:33 +08:00

926 lines
29 KiB
Python

"""
Enhanced API security with JWT, OAuth, and Malaysian-specific authentication.
"""
import json
import secrets
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.cache import cache
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.utils import timezone
from django.views.decorators.csrf import csrf_exempt
from django.views.decorators.http import require_http_methods
from rest_framework import status
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework_simplejwt.tokens import RefreshToken, AccessToken
from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
from jose import jwt, JWTError
from jose.exceptions import ExpiredSignatureError, JWTClaimsError
import logging
import hashlib
import hmac
from urllib.parse import urlparse
import requests
from requests.auth import HTTPBasicAuth
from .auth import SecurePasswordManager, SecureSessionManager
from .middleware import SecurityLoggingMiddleware
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# User model class as returned by Django's get_user_model().
User = get_user_model()
class MalaysianBusinessAuthenticator:
    """
    Malaysian business authentication and authorization system.

    A login needs three credentials taken from ``request.data``: the
    business registration number, the account password and a 6-digit SSM
    code. Every failure path returns ``None`` so callers cannot tell which
    check rejected the attempt.
    """

    # Number of ASCII digits a well-formed SSM code must contain.
    SSM_CODE_LENGTH = 6

    def __init__(self):
        self.password_manager = SecurePasswordManager()
        self.session_manager = SecureSessionManager()
        self.logger = logging.getLogger('security.auth')

    def authenticate_business(self, request: HttpRequest) -> Optional[Dict]:
        """
        Authenticate a Malaysian business from the request payload.

        Args:
            request: Request whose ``data`` mapping carries
                ``business_registration``, ``password`` and ``ssm_code``.
                NOTE(review): annotated as HttpRequest but ``.data`` is read,
                which looks like a DRF request -- confirm with callers.

        Returns:
            A dict with ``user``, ``business_registration`` and ``ssm_code``
            on success, otherwise ``None`` (including on unexpected errors,
            which are logged).
        """
        try:
            # Get credentials
            business_registration = request.data.get('business_registration')
            password = request.data.get('password')
            ssm_code = request.data.get('ssm_code')

            # Validate required fields
            if not all([business_registration, password, ssm_code]):
                return None

            # Verify SSM code format
            if not self._validate_ssm_code(ssm_code):
                # The SSM code is a credential: never write its value to logs.
                self.logger.warning(
                    f"Invalid SSM code format for business: {business_registration}"
                )
                return None

            # Get user by business registration
            try:
                user = User.objects.get(
                    business_registration__registration_number=business_registration,
                    is_active=True
                )
            except User.DoesNotExist:
                self.logger.warning(f"Business not found: {business_registration}")
                return None

            # Verify password
            if not self.password_manager.verify_password(password, user.password):
                self.logger.warning(f"Invalid password for business: {business_registration}")
                return None

            # Verify SSM code
            if not self._verify_ssm_code(user, ssm_code):
                self.logger.warning(f"Invalid SSM code for business: {business_registration}")
                return None

            # Check business status
            if not self._check_business_status(user):
                self.logger.warning(f"Business not active: {business_registration}")
                return None

            # Update last login. Only that column is written so a concurrent
            # change to another field of the same row is not overwritten.
            user.last_login = timezone.now()
            user.save(update_fields=['last_login'])

            return {
                'user': user,
                'business_registration': business_registration,
                'ssm_code': ssm_code,
            }
        except Exception as e:
            self.logger.error(f"Authentication error: {e}")
            return None

    def _validate_ssm_code(self, ssm_code: str) -> bool:
        """
        Validate SSM code format.

        Returns ``True`` only for a string of exactly ``SSM_CODE_LENGTH``
        ASCII digits. Non-string values are rejected instead of raising, and
        the ASCII check excludes characters such as Arabic-Indic digits that
        ``str.isdigit()`` alone would accept.
        """
        if not isinstance(ssm_code, str):
            return False
        return (
            len(ssm_code) == self.SSM_CODE_LENGTH
            and ssm_code.isascii()
            and ssm_code.isdigit()
        )

    def _verify_ssm_code(self, user: User, ssm_code: str) -> bool:
        """
        Verify SSM code against the hash stored on ``user.ssm_code_hash``.

        Returns ``False`` when no hash is stored or when verification raises.
        """
        try:
            # Get stored SSM code hash
            stored_hash = getattr(user, 'ssm_code_hash', None)
            if not stored_hash:
                return False
            # Verify hash
            return self.password_manager.verify_password(ssm_code, stored_hash)
        except Exception as e:
            self.logger.error(f"SSM code verification error: {e}")
            return False

    def _check_business_status(self, user: User) -> bool:
        """
        Check if business is active and compliant.

        The user must be active and must have a ``business_registration``
        whose ``is_active`` and ``is_compliant`` flags are both truthy. Any
        error while reading those attributes counts as "not active".
        """
        try:
            # Check if user is active
            if not user.is_active:
                return False

            # Check business registration status
            business_registration = getattr(user, 'business_registration', None)
            if not business_registration:
                return False

            # Check if registration is active
            if not business_registration.is_active:
                return False

            # Check compliance status
            if not business_registration.is_compliant:
                return False

            return True
        except Exception as e:
            self.logger.error(f"Business status check error: {e}")
            return False
class SecureJWTAuthentication:
    """
    Enhanced JWT authentication with Malaysian security considerations.

    Tokens are HS256-signed with ``settings.SECRET_KEY``. Lifetimes come from
    ``settings.SIMPLE_JWT`` when present (15 minutes / 7 days otherwise).
    Revocation is tracked in the Django cache under ``revoked_token_<jti>``.
    """

    def __init__(self):
        jwt_settings = getattr(settings, 'SIMPLE_JWT', {})
        self.secret_key = settings.SECRET_KEY
        self.algorithm = 'HS256'
        self.access_token_lifetime = jwt_settings.get('ACCESS_TOKEN_LIFETIME', timedelta(minutes=15))
        self.refresh_token_lifetime = jwt_settings.get('REFRESH_TOKEN_LIFETIME', timedelta(days=7))
        self.logger = logging.getLogger('security.jwt')

    def generate_token_pair(self, user: User) -> Dict[str, str]:
        """
        Generate secure JWT token pair.

        Returns:
            Dict with ``access_token``, ``refresh_token``, ``token_type``
            (always ``'Bearer'``) and ``expires_in`` (access lifetime in
            whole seconds).

        Raises:
            Whatever token generation raised; the error is logged first.
        """
        try:
            # Generate access token
            access_token = self._generate_access_token(user)
            # Generate refresh token
            refresh_token = self._generate_refresh_token(user)
            return {
                'access_token': access_token,
                'refresh_token': refresh_token,
                'token_type': 'Bearer',
                'expires_in': int(self.access_token_lifetime.total_seconds()),
            }
        except Exception as e:
            self.logger.error(f"Token generation error: {e}")
            raise

    @staticmethod
    def _registration_field(registration: Any, field: str) -> Any:
        """
        Read ``field`` from a business registration, or ``''`` if unavailable.

        Elsewhere in this module the registration is used as an object with
        attributes, so attribute access is the main path; plain dicts are
        still accepted for backward compatibility.
        """
        if registration is None:
            return ''
        if isinstance(registration, dict):
            return registration.get(field, '')
        return getattr(registration, field, '')

    def _generate_access_token(self, user: User) -> str:
        """
        Generate access token with enhanced claims.

        ``iat`` / ``exp`` are written as integer Unix timestamps, which is
        also the form ``revoke_token`` expects when computing a TTL.
        """
        issued_at = int(time.time())
        registration = getattr(user, 'business_registration', None)

        # Basic claims
        claims = {
            'user_id': user.id,
            'username': user.username,
            'email': user.email,
            'business_registration': self._registration_field(registration, 'registration_number'),
            'exp': issued_at + int(self.access_token_lifetime.total_seconds()),
            'iat': issued_at,
            'jti': secrets.token_urlsafe(16),
            'type': 'access',
        }
        # Add Malaysian-specific claims
        claims.update({
            'malaysian_business': True,
            'business_type': self._registration_field(registration, 'business_type'),
            'state': self._registration_field(registration, 'state'),
            'compliance_status': self._registration_field(registration, 'compliance_status'),
        })
        # Add security claims
        claims.update({
            'device_fingerprint': self._get_device_fingerprint(),
            'ip_address': self._get_client_ip(),
            'session_id': self._generate_session_id(),
        })
        return jwt.encode(claims, self.secret_key, algorithm=self.algorithm)

    def _generate_refresh_token(self, user: User) -> str:
        """
        Generate refresh token carrying only identity, timing and ``jti``.
        """
        issued_at = int(time.time())
        claims = {
            'user_id': user.id,
            'exp': issued_at + int(self.refresh_token_lifetime.total_seconds()),
            'iat': issued_at,
            'jti': secrets.token_urlsafe(16),
            'type': 'refresh',
        }
        return jwt.encode(claims, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token: str, token_type: str = 'access') -> Optional[Dict]:
        """
        Verify JWT token with enhanced validation.

        Args:
            token: Encoded JWT.
            token_type: Value the ``type`` claim must have.

        Returns:
            The decoded claims, or ``None`` when the token is expired,
            malformed, of the wrong type, fails claim validation or has
            been revoked.
        """
        try:
            # Decode token
            claims = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])

            # Validate token type
            if claims.get('type') != token_type:
                self.logger.warning(f"Invalid token type: {token_type}")
                return None

            # Validate claims
            if not self._validate_claims(claims):
                return None

            # Check if token is revoked
            if self._is_token_revoked(claims):
                self.logger.warning("Token revoked")
                return None

            return claims
        except ExpiredSignatureError:
            self.logger.warning("Token expired")
            return None
        except JWTClaimsError as e:
            self.logger.warning(f"Invalid claims: {e}")
            return None
        except JWTError as e:
            self.logger.warning(f"JWT error: {e}")
            return None

    def _validate_claims(self, claims: Dict) -> bool:
        """
        Validate token claims.

        Requires ``user_id``, ``jti``, ``type``, ``iat`` and ``exp``; the
        referenced user must exist and be active; a ``device_fingerprint``
        claim, when present, must equal the current fingerprint.
        """
        # Check required claims
        required_claims = ['user_id', 'jti', 'type', 'iat', 'exp']
        for claim in required_claims:
            if claim not in claims:
                self.logger.warning(f"Missing required claim: {claim}")
                return False

        # Validate user exists
        try:
            user = User.objects.get(id=claims['user_id'])
            if not user.is_active:
                self.logger.warning("User not active")
                return False
        except User.DoesNotExist:
            self.logger.warning("User not found")
            return False

        # Validate device fingerprint if present
        if 'device_fingerprint' in claims:
            if claims['device_fingerprint'] != self._get_device_fingerprint():
                self.logger.warning("Device fingerprint mismatch")
                return False

        return True

    def _is_token_revoked(self, claims: Dict) -> bool:
        """
        Check if token is revoked.

        Fails closed: a missing ``jti`` or a cache error is reported as
        revoked.
        """
        try:
            jti = claims.get('jti')
            if not jti:
                return True
            # Check cache for revoked tokens
            cache_key = f"revoked_token_{jti}"
            return cache.get(cache_key) is not None
        except Exception as e:
            self.logger.error(f"Token revocation check error: {e}")
            return True

    def revoke_token(self, claims: Dict):
        """
        Revoke token by adding its ``jti`` to the revoked list in the cache.

        The marker is kept until the token's ``exp``. If ``exp`` is missing
        the refresh-token lifetime is used as the TTL, so such a token is
        still revoked rather than silently skipped. Errors are logged and
        not raised.
        """
        try:
            jti = claims.get('jti')
            if not jti:
                return

            exp = claims.get('exp')
            if exp:
                # Cache until token expires
                ttl = max(0, int(exp) - int(time.time()))
            else:
                ttl = int(self.refresh_token_lifetime.total_seconds())

            if ttl <= 0:
                # Already expired: nothing needs to be remembered.
                return

            # Add to revoked tokens cache
            cache.set(f"revoked_token_{jti}", True, timeout=ttl)
        except Exception as e:
            self.logger.error(f"Token revocation error: {e}")

    def _get_device_fingerprint(self) -> str:
        """
        Generate device fingerprint.

        Currently built from placeholder values (user agent ``"unknown"``
        and the placeholder IP), so it is identical for every caller and
        the fingerprint check in ``_validate_claims`` cannot fail yet.
        """
        user_agent = "unknown"  # Would get from request
        ip_address = self._get_client_ip()
        fingerprint = f"{user_agent}:{ip_address}"
        return hashlib.sha256(fingerprint.encode()).hexdigest()

    def _get_client_ip(self) -> str:
        """
        Get client IP address.
        """
        # This would get from request headers
        return "127.0.0.1"  # Placeholder

    def _generate_session_id(self) -> str:
        """
        Generate a random URL-safe session ID.
        """
        return secrets.token_urlsafe(32)
class OAuth2Provider:
    """
    OAuth2 provider integration for Malaysian services.

    Provider endpoints and client credentials are read from
    ``settings.OAUTH2_PROVIDERS`` (a mapping keyed by provider name with
    ``token_url``, ``client_id``, ``client_secret`` and ``user_info_url``).
    """

    def __init__(self):
        self.providers = getattr(settings, 'OAUTH2_PROVIDERS', {})
        self.logger = logging.getLogger('security.oauth2')

    def authenticate_with_oauth2(self, provider: str, code: str, redirect_uri: str) -> Optional[Dict]:
        """
        Authenticate using OAuth2 provider (authorization-code flow).

        Returns:
            Dict with ``user``, ``provider``, ``access_token``,
            ``refresh_token`` and ``expires_in``, or ``None`` when the
            provider is unknown or any step fails.
        """
        try:
            if provider not in self.providers:
                self.logger.warning(f"Unknown OAuth2 provider: {provider}")
                return None
            provider_config = self.providers[provider]

            # Exchange code for access token
            token_response = self._exchange_code_for_token(
                provider_config, code, redirect_uri
            )
            if not token_response:
                return None

            # A 200 response can still carry an error body with no token;
            # reject it explicitly instead of relying on a KeyError.
            access_token = (
                token_response.get('access_token')
                if isinstance(token_response, dict) else None
            )
            if not access_token:
                self.logger.warning("Token response missing access_token")
                return None

            # Get user info
            user_info = self._get_user_info(provider_config, access_token)
            if not user_info:
                return None

            # Find or create user
            user = self._find_or_create_user(provider, user_info)
            if not user:
                return None

            return {
                'user': user,
                'provider': provider,
                'access_token': access_token,
                'refresh_token': token_response.get('refresh_token'),
                'expires_in': token_response.get('expires_in'),
            }
        except Exception as e:
            self.logger.error(f"OAuth2 authentication error: {e}")
            return None

    def _exchange_code_for_token(self, provider_config: Dict, code: str, redirect_uri: str) -> Optional[Dict]:
        """
        Exchange authorization code for access token.

        POSTs the code and client credentials to the provider's
        ``token_url`` (10 s timeout). Returns the decoded JSON body, or
        ``None`` on a non-200 status or any error.
        """
        try:
            data = {
                'grant_type': 'authorization_code',
                'code': code,
                'redirect_uri': redirect_uri,
                'client_id': provider_config['client_id'],
                'client_secret': provider_config['client_secret'],
            }
            response = requests.post(provider_config['token_url'], data=data, timeout=10)
            if response.status_code != 200:
                self.logger.warning(f"Token exchange failed: {response.status_code}")
                return None
            return response.json()
        except Exception as e:
            self.logger.error(f"Token exchange error: {e}")
            return None

    def _get_user_info(self, provider_config: Dict, access_token: str) -> Optional[Dict]:
        """
        Get user info from OAuth2 provider.

        GETs ``user_info_url`` with a Bearer header (10 s timeout). Returns
        the decoded JSON body, or ``None`` on a non-200 status or any error.
        """
        try:
            headers = {'Authorization': f'Bearer {access_token}'}
            response = requests.get(provider_config['user_info_url'], headers=headers, timeout=10)
            if response.status_code != 200:
                self.logger.warning(f"User info request failed: {response.status_code}")
                return None
            return response.json()
        except Exception as e:
            self.logger.error(f"User info request error: {e}")
            return None

    def _find_or_create_user(self, provider: str, user_info: Dict) -> Optional[User]:
        """
        Find or create user from OAuth2 provider info.

        Users are matched by email. Because that match links the external
        identity to an existing local account, a profile that explicitly
        reports ``email_verified`` as false is rejected; profiles that omit
        the claim are accepted as before.

        Returns:
            The matched or newly created user, or ``None`` when the email is
            missing or unverified, or on any error.
        """
        try:
            # Try to find user by email
            email = user_info.get('email')
            if not email:
                self.logger.warning("No email in user info")
                return None

            verified = user_info.get('email_verified')
            if verified is False or (isinstance(verified, str) and verified.lower() == 'false'):
                self.logger.warning("OAuth2 email not verified")
                return None

            try:
                user = User.objects.get(email=email)
            except User.DoesNotExist:
                # Create new user
                return User.objects.create(
                    username=email,
                    email=email,
                    first_name=user_info.get('given_name', ''),
                    last_name=user_info.get('family_name', ''),
                    is_active=True,
                    oauth2_providers={provider: user_info}
                )

            # Update OAuth2 info; tolerate a missing or None attribute.
            linked = getattr(user, 'oauth2_providers', None) or {}
            linked[provider] = user_info
            user.oauth2_providers = linked
            user.save()
            return user
        except Exception as e:
            self.logger.error(f"User creation error: {e}")
            return None
class APIRateLimiter:
    """
    Enhanced API rate limiting with Malaysian considerations.

    Counters live in the Django cache under ``rate_limit_<key>_<endpoint>``
    and use a fixed window: the TTL is set when the counter is created and
    is not extended by later requests.
    """

    def __init__(self):
        self.logger = logging.getLogger('security.ratelimit')
        # Rate limits
        self.rate_limits = {
            'default': {'requests': 100, 'window': 60},  # 100 requests per minute
            'login': {'requests': 5, 'window': 60},  # 5 login attempts per minute
            'api': {'requests': 1000, 'window': 60},  # 1000 API requests per minute
            'upload': {'requests': 10, 'window': 60},  # 10 uploads per minute
            'export': {'requests': 20, 'window': 60},  # 20 exports per minute
        }
        # Malaysian business limits
        self.malaysian_business_limits = {
            'sst_calculation': {'requests': 50, 'window': 60},
            'ic_validation': {'requests': 30, 'window': 60},
            'postcode_lookup': {'requests': 100, 'window': 60},
        }

    def _resolve_limit(self, endpoint: str) -> Dict:
        """
        Return the limit config for ``endpoint``.

        Malaysian business limits take precedence, then the general table,
        then the ``'default'`` entry.
        """
        if endpoint in self.malaysian_business_limits:
            return self.malaysian_business_limits[endpoint]
        return self.rate_limits.get(endpoint, self.rate_limits['default'])

    def check_rate_limit(self, key: str, endpoint: str = 'default') -> bool:
        """
        Check if request is within rate limit and count it.

        Args:
            key: Caller identity (e.g. ``user_<id>`` or ``ip_<addr>``).
            endpoint: Name of the limit to apply.

        Returns:
            ``True`` when the request is allowed, ``False`` once more than
            ``requests`` hits were counted in the current window. Errors are
            logged and the request is allowed (fail-open).
        """
        try:
            limit = self._resolve_limit(endpoint)
            cache_key = f"rate_limit_{key}_{endpoint}"

            # add() only creates the counter when it is absent, so the TTL is
            # set once per window instead of being pushed out on every hit.
            cache.add(cache_key, 0, timeout=limit['window'])
            try:
                current = cache.incr(cache_key)
            except ValueError:
                # Counter expired between add() and incr(): start a new window.
                cache.set(cache_key, 1, timeout=limit['window'])
                current = 1

            if current > limit['requests']:
                self.logger.warning(f"Rate limit exceeded for {key} on {endpoint}")
                return False
            return True
        except Exception as e:
            self.logger.error(f"Rate limit check error: {e}")
            return True  # Allow on error

    def get_rate_limit_info(self, key: str, endpoint: str = 'default') -> Dict:
        """
        Get current rate limit information.

        Returns:
            Dict with ``limit``, ``remaining``, ``reset`` (seconds left in
            the window) and ``window``. ``reset`` falls back to the full
            window when the cache backend has no ``ttl`` method or reports
            no usable TTL. All values are 0 if an error occurs.
        """
        try:
            limit = self._resolve_limit(endpoint)
            cache_key = f"rate_limit_{key}_{endpoint}"
            current = cache.get(cache_key, 0)

            # ttl() is not part of every cache backend and may return None.
            ttl_lookup = getattr(cache, 'ttl', None)
            ttl = ttl_lookup(cache_key) if callable(ttl_lookup) else None
            if ttl is None or ttl < 0:
                ttl = limit['window']

            return {
                'limit': limit['requests'],
                'remaining': max(0, limit['requests'] - current),
                'reset': ttl,
                'window': limit['window'],
            }
        except Exception as e:
            self.logger.error(f"Rate limit info error: {e}")
            return {'limit': 0, 'remaining': 0, 'reset': 0, 'window': 0}
# Authentication Views
class MalaysianBusinessTokenView(TokenObtainPairView):
    """
    Token endpoint that logs a Malaysian business in and issues a JWT pair.
    """
    serializer_class = TokenObtainPairSerializer

    def post(self, request, *args, **kwargs):
        """
        Authenticate the business in ``request`` and respond with tokens.

        Responds 200 with the token pair, 401 when the credentials are
        rejected, and 500 when anything raises.
        """
        try:
            outcome = MalaysianBusinessAuthenticator().authenticate_business(request)
            if outcome:
                # Credentials accepted: mint the pair for the resolved user.
                token_pair = SecureJWTAuthentication().generate_token_pair(outcome['user'])
                logger.info(f"Successful authentication for business: {outcome['business_registration']}")
                return Response(token_pair, status=status.HTTP_200_OK)

            return Response(
                {'error': 'Invalid credentials'},
                status=status.HTTP_401_UNAUTHORIZED
            )
        except Exception as exc:
            logger.error(f"Authentication error: {exc}")
            return Response(
                {'error': 'Authentication failed'},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def refresh_token_view(request):
    """
    Exchange a refresh token for a new token pair.

    Reads ``refresh_token`` from the request data, verifies it as a refresh
    token, issues a new pair for its user and then revokes the presented
    token. Responds 400 when the token is missing, 401 when it is invalid,
    404 when its user no longer exists and 500 on unexpected errors.
    """
    # NOTE(review): IsAuthenticated means the caller must already be
    # authenticated in order to refresh -- confirm this is intended.

    def _failure(message, http_status):
        # Uniform error envelope used by every failure branch below.
        return Response({'error': message}, status=http_status)

    try:
        presented = request.data.get('refresh_token')
        if not presented:
            return _failure('Refresh token required', status.HTTP_400_BAD_REQUEST)

        token_service = SecureJWTAuthentication()
        refresh_claims = token_service.verify_token(presented, 'refresh')
        if not refresh_claims:
            return _failure('Invalid refresh token', status.HTTP_401_UNAUTHORIZED)

        try:
            account = User.objects.get(id=refresh_claims['user_id'])
        except User.DoesNotExist:
            return _failure('User not found', status.HTTP_404_NOT_FOUND)

        # Issue the replacement pair first, then retire the old refresh token.
        issued = token_service.generate_token_pair(account)
        token_service.revoke_token(refresh_claims)
        return Response(issued, status=status.HTTP_200_OK)
    except Exception as exc:
        logger.error(f"Token refresh error: {exc}")
        return _failure('Token refresh failed', status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def revoke_token_view(request):
    """
    Revoke JWT token.

    Reads ``token`` from the request data and revokes it. Both access and
    refresh tokens are accepted: the token is verified as an access token
    first and, if that fails, as a refresh token, so a refresh token can be
    invalidated explicitly (for example on logout).

    Responds 200 on success, 400 when ``token`` is missing, 401 when the
    token verifies as neither type and 500 on unexpected errors.
    """
    try:
        token = request.data.get('token')
        if not token:
            return Response(
                {'error': 'Token required'},
                status=status.HTTP_400_BAD_REQUEST
            )

        # Initialize JWT authentication
        jwt_auth = SecureJWTAuthentication()

        # Verify token. verify_token() checks the 'type' claim against the
        # expected type, so each accepted type has to be tried in turn.
        claims = (
            jwt_auth.verify_token(token, 'access')
            or jwt_auth.verify_token(token, 'refresh')
        )
        if not claims:
            return Response(
                {'error': 'Invalid token'},
                status=status.HTTP_401_UNAUTHORIZED
            )

        # Revoke token
        jwt_auth.revoke_token(claims)
        return Response({'message': 'Token revoked'}, status=status.HTTP_200_OK)
    except Exception as e:
        logger.error(f"Token revocation error: {e}")
        return Response(
            {'error': 'Token revocation failed'},
            status=status.HTTP_500_INTERNAL_SERVER_ERROR
        )
@api_view(['POST'])
def oauth2_login_view(request):
    """
    Log a user in through an OAuth2 provider and issue a JWT pair.

    Expects ``provider``, ``code`` and ``redirect_uri`` in the request data.
    Responds 200 with the token pair, 400 when a field is missing, 401 when
    the OAuth2 exchange fails and 500 on unexpected errors.
    """
    try:
        payload = request.data
        provider_name = payload.get('provider')
        auth_code = payload.get('code')
        callback_uri = payload.get('redirect_uri')

        if not (provider_name and auth_code and callback_uri):
            return Response(
                {'error': 'Missing required fields'},
                status=status.HTTP_400_BAD_REQUEST
            )

        # Run the authorization-code exchange against the named provider.
        outcome = OAuth2Provider().authenticate_with_oauth2(
            provider_name, auth_code, callback_uri
        )
        if not outcome:
            return Response(
                {'error': 'OAuth2 authentication failed'},
                status=status.HTTP_401_UNAUTHORIZED
            )

        # Mint the local JWT pair for the resolved user.
        token_pair = SecureJWTAuthentication().generate_token_pair(outcome['user'])
        logger.info(f"Successful OAuth2 authentication for user: {outcome['user'].username}")
        return Response(token_pair, status=status.HTTP_200_OK)
    except Exception as exc:
        logger.error(f"OAuth2 login error: {exc}")
        return Response(
            {'error': 'OAuth2 authentication failed'},
            status=status.HTTP_500_INTERNAL_SERVER_ERROR
        )
# Middleware for API security
class APISecurityMiddleware:
    """
    Middleware for API security with rate limiting and authentication.

    For every request that is not skipped, the rate limit is checked first
    (429 on excess); paths that require authentication must then present a
    valid Bearer access token (401 otherwise).
    """

    # Path prefixes that bypass every check in this middleware.
    _SKIP_PREFIXES = (
        '/health/',
        '/metrics/',
        '/static/',
        '/media/',
        '/api/v1/docs/',
        '/api/v1/oauth2/',
    )

    # Path prefixes that require a valid Bearer token.
    _AUTH_PREFIXES = (
        '/api/v1/',
        '/business/',
        '/admin/',
    )

    def __init__(self, get_response):
        self.get_response = get_response
        self.rate_limiter = APIRateLimiter()
        self.jwt_auth = SecureJWTAuthentication()
        self.logger = logging.getLogger('security.api')

    def __call__(self, request):
        # Skip for certain endpoints
        if self._should_skip(request):
            return self.get_response(request)

        # Check rate limit
        if not self._check_rate_limit(request):
            return JsonResponse(
                {'error': 'Rate limit exceeded'},
                status=status.HTTP_429_TOO_MANY_REQUESTS
            )

        # Check authentication
        if self._requires_auth(request):
            auth_result = self._check_authentication(request)
            if not auth_result:
                return JsonResponse(
                    {'error': 'Authentication required'},
                    status=status.HTTP_401_UNAUTHORIZED
                )
            # Add user to request
            request.user = auth_result['user']
            request.auth_claims = auth_result['claims']

        return self.get_response(request)

    def _should_skip(self, request) -> bool:
        """
        Determine if security checks should be skipped.
        """
        return request.path.startswith(self._SKIP_PREFIXES)

    @staticmethod
    def _endpoint_for_path(path: str) -> str:
        """
        Map a request path to the name of the rate limit that applies.

        Specific endpoint kinds are tested before the generic ``'api'``
        bucket; otherwise any ``/api/...`` path would match ``'api'`` first
        and the stricter upload/export/SST/IC/postcode limits could never
        apply. ``'ic'`` is matched as a whole path token (split on ``/``,
        ``-`` and ``_``) because as a substring it also occurs in unrelated
        paths such as ``/public/``.
        """
        lowered = path.lower()
        tokens = set(lowered.replace('-', '/').replace('_', '/').split('/'))

        if 'login' in lowered:
            return 'login'
        if 'upload' in lowered:
            return 'upload'
        if 'export' in lowered:
            return 'export'
        if 'sst' in lowered:
            return 'sst_calculation'
        if 'ic' in tokens:
            return 'ic_validation'
        if 'postcode' in lowered:
            return 'postcode_lookup'
        if 'api' in lowered:
            return 'api'
        return 'default'

    def _check_rate_limit(self, request) -> bool:
        """
        Check rate limit for the request.

        The counter is keyed by user id when ``request.user`` is already
        authenticated, otherwise by client IP. Errors are logged and the
        request is allowed (fail-open).
        """
        try:
            # Get rate limit key
            if hasattr(request, 'user') and request.user.is_authenticated:
                key = f"user_{request.user.id}"
            else:
                key = f"ip_{self._get_client_ip(request)}"

            endpoint = self._endpoint_for_path(request.path)
            return self.rate_limiter.check_rate_limit(key, endpoint)
        except Exception as e:
            self.logger.error(f"Rate limit check error: {e}")
            return True  # Allow on error

    def _requires_auth(self, request) -> bool:
        """
        Check if request requires authentication.
        """
        return request.path.startswith(self._AUTH_PREFIXES)

    def _check_authentication(self, request) -> Optional[Dict]:
        """
        Check request authentication.

        Expects an ``Authorization: Bearer <token>`` header; the scheme name
        is compared case-insensitively.

        Returns:
            Dict with ``user`` and ``claims`` for a valid access token whose
            user exists, otherwise ``None``.
        """
        try:
            # Get Authorization header
            auth_header = request.META.get('HTTP_AUTHORIZATION')
            if not auth_header:
                return None

            # Parse Bearer token
            scheme, _, token = auth_header.partition(' ')
            token = token.strip()
            if scheme.lower() != 'bearer' or not token:
                return None

            # Verify token
            claims = self.jwt_auth.verify_token(token)
            if not claims:
                return None

            # Get user
            try:
                user = User.objects.get(id=claims['user_id'])
            except User.DoesNotExist:
                return None

            return {
                'user': user,
                'claims': claims,
            }
        except Exception as e:
            self.logger.error(f"Authentication check error: {e}")
            return None

    def _get_client_ip(self, request) -> str:
        """
        Get client IP address from request.

        Uses the first entry of ``X-Forwarded-For`` (whitespace stripped)
        when present and non-empty, else ``REMOTE_ADDR``.

        NOTE(review): X-Forwarded-For is client-controlled unless a trusted
        proxy overwrites it -- confirm the deployment guarantees that,
        since this value keys the per-IP rate limit.
        """
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            first_hop = x_forwarded_for.split(',')[0].strip()
            if first_hop:
                return first_hop
        return request.META.get('REMOTE_ADDR')