diff --git a/api/src/BaseClasses/BaseController.cs b/api/src/BaseClasses/BaseController.cs index d17fa40..590197a 100644 --- a/api/src/BaseClasses/BaseController.cs +++ b/api/src/BaseClasses/BaseController.cs @@ -5,50 +5,24 @@ namespace WinStudentGoalTracker.BaseClasses; public class BaseController : ControllerBase { - protected (Guid userId, ActionResult? error) GetUserIdFromClaims() + protected (Guid userId, string email, Guid programId, string role, ActionResult? error) GetProgramUserFromClaims() { - var userIdClaim = User.FindFirst("user_id")?.Value - ?? User.FindFirst(ClaimTypes.NameIdentifier)?.Value; + var userIdClaim = User.FindFirst("user_id")?.Value; + if (!Guid.TryParse(userIdClaim, out var userId)) + return (Guid.Empty, string.Empty, Guid.Empty, string.Empty, Unauthorized("Missing or invalid user_id claim.")); - if (string.IsNullOrWhiteSpace(userIdClaim) || !Guid.TryParse(userIdClaim, out var userId)) - { - return (Guid.Empty, Unauthorized("Missing or invalid user_id claim.")); - } - - return (userId, null); - } - - protected (string email, List roles, ActionResult? error) GetUserDetailsFromClaims() - { var email = User.FindFirst(ClaimTypes.Email)?.Value; if (string.IsNullOrWhiteSpace(email)) - { - return (string.Empty, new List(), Unauthorized("Missing email claim.")); - } + return (Guid.Empty, string.Empty, Guid.Empty, string.Empty, Unauthorized("Missing email claim.")); - var roles = User.FindAll(ClaimTypes.Role).Select(claim => claim.Value).ToList(); - return (email, roles, null); - } - - protected bool HasRole(string role) - { - return User.IsInRole(role); - } - - protected bool HasAnyRole(params string[] roles) - { - return roles.Any(User.IsInRole); - } - - protected (Guid programId, ActionResult? error) GetProgramIdFromClaims() - { var programIdClaim = User.FindFirst("program_id")?.Value; + if (!Guid.TryParse(programIdClaim, out var programId)) + return (Guid.Empty, string.Empty, Guid.Empty, string.Empty, Unauthorized("Missing or invalid program_id claim.")); - if (string.IsNullOrWhiteSpace(programIdClaim) || !Guid.TryParse(programIdClaim, out var programId)) - { - return (Guid.Empty, Unauthorized("Missing or invalid program_id claim.")); - } + var role = User.FindFirst(ClaimTypes.Role)?.Value; + if (string.IsNullOrWhiteSpace(role)) + return (Guid.Empty, string.Empty, Guid.Empty, string.Empty, Unauthorized("Missing role claim.")); - return (programId, null); + return (userId, email, programId, role, null); } } diff --git a/api/src/Controllers/AuthController.cs b/api/src/Controllers/AuthController.cs index 895204c..0158f44 100644 --- a/api/src/Controllers/AuthController.cs +++ b/api/src/Controllers/AuthController.cs @@ -15,6 +15,7 @@ public class AuthController : BaseController private readonly UserRepository _userRepo = new(); private readonly AuthRepository _authRepo = new(); private readonly TokenService _tokenService; + private static readonly int _loginExpiration = 60 * 60 * 24 * 31; // Refresh token expires after 1 month. public AuthController(TokenService tokenService) { @@ -125,8 +126,9 @@ public class AuthController : BaseController }); } - var (userId, userIdError) = GetUserIdFromClaims(); - if (userIdError != null) return userIdError; + var userIdClaim = User.FindFirst("user_id")?.Value; + if (!Guid.TryParse(userIdClaim, out Guid userId)) + return Unauthorized(new ResponseResult { Success = false, Message = "Invalid session token." }); if (!Guid.TryParse(dto.ProgramId, out Guid programId)) { @@ -177,7 +179,7 @@ public class AuthController : BaseController programUser.IdProgram, refreshTokenHash, refreshTokenSalt, - expiresInSeconds: 2592000, // 30 days + expiresInSeconds: _loginExpiration, deviceInfo: deviceInfo, userAgent: userAgent ); @@ -296,13 +298,13 @@ public class AuthController : BaseController }); } - var newJwtToken = _tokenService.GenerateToken( + var newJwt = _tokenService.GenerateToken( programUser.IdUser, programUser.Email!, programUser.RoleInternalName, programUser.IdProgram); - var jwtExpiresIn = _tokenService.GetTokenExpiryInSeconds(newJwtToken); + var jwtExpiresIn = _tokenService.GetTokenExpiryInSeconds(newJwt); var newSecretToken = Guid.NewGuid().ToString(); var (newRefreshTokenHash, newRefreshTokenSalt) = PasswordHasher.HashPassword(newSecretToken); @@ -340,7 +342,7 @@ public class AuthController : BaseController Message = "Token refreshed successfully.", Data = new TokenRefreshResponse { - Jwt = newJwtToken, + Jwt = newJwt, NewRefreshToken = fullNewRefreshToken, JwtExpiresIn = jwtExpiresIn } @@ -362,8 +364,8 @@ public class AuthController : BaseController }); } - var (userId, error) = GetUserIdFromClaims(); - if (error != null) return error; + var (userId, _, _, _, claimsError) = GetProgramUserFromClaims(); + if (claimsError != null) return claimsError; var dotIndex = logoutDto.RefreshToken.IndexOf('.'); if (dotIndex < 1 || !Guid.TryParse(logoutDto.RefreshToken[..dotIndex], out Guid tokenId))