Tags: asp.net-mvc, cookies, asp.net-core-mvc, middleware

Cookie middleware in MVC 6


I am trying to understand how the cookie authentication middleware behaves and how its flow works, but I can't make sense of it. Here are my questions:

  1. Does the cookie have to be set and read with a ClaimsIdentity, rather than with ASP.NET Identity?
  2. I have implemented "Remember Me" functionality, which creates the cookie. Do I need the cookie middleware for this or not? I am using the code below:

    var result = await _signInManager.PasswordSignInAsync(model.Email, model.Password, model.RememberMe, lockoutOnFailure: true);

So which is the preferred way to create the cookie: point 1 or point 2?

  3. How can I validate the user based on the cookie using a ClaimsIdentity?

Any help is appreciated!


Solution

  • This solution consists of multiple files:

    1. Authentication middleware
    2. Authorization attribute
    3. Cookie management library (Security manager (ISecurityManager))
    4. Security extension library (for middleware wrapper & for claims manipulation)

    Authentication middleware:

        using System.Collections.Generic;
        using System.Linq;
        using System.Threading.Tasks;
        using Microsoft.AspNetCore.Http;

        public class CustomAuthentication
        {
            private readonly RequestDelegate _next;
            private readonly ISecurityManager _securityManager;

            // Paths anonymous users may open without being sent to the login page
            private readonly List<string> _publiclyAccessiblePaths = new List<string> {
                "security/forgotpassword",
                "security/resetpassword"
            };

            public CustomAuthentication(RequestDelegate next, ISecurityManager securityManager)
            {
                _next = next;
                _securityManager = securityManager;
            }

            public async Task Invoke(HttpContext context)
            {
                // Validates the auth cookie and, if valid, attaches the user's claims to context.User
                bool authenticated = await _securityManager.ProcessSecurityContext(context);

                string path = (context.Request.Path.Value ?? string.Empty).Trim().ToLowerInvariant();
                bool containsLoginPath = path.Contains("security/login");
                bool containsPublicPaths = _publiclyAccessiblePaths.Any(x => path.Contains(x));
                bool blockPipeline = false;

                if (!containsPublicPaths) {
                    if (authenticated) {
                        if (containsLoginPath) {
                            // Authenticated users may not open the login page; send them home.
                            // Response.Redirect() sets status 302 itself.
                            context.Response.Redirect("/");
                            blockPipeline = true;
                        }
                    } else {
                        if (!containsLoginPath) {
                            // Anonymous users are sent to the login page. Redirect() overwrites the
                            // status code with 302, so setting 401 beforehand would have no effect.
                            context.Response.Redirect("/security/login");
                            blockPipeline = true;
                        }
                    }
                }

                // Only continue down the pipeline if no redirect was issued
                if (!blockPipeline)
                    await _next(context);
            }
        }
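The branching in `Invoke` boils down to a three-way decision per request. The following is a minimal sketch that pulls that decision into a pure, framework-free function so it can be unit-tested without an `HttpContext`. The names `AuthRouting` and `RouteDecision` are hypothetical (not part of the answer), and the matching deliberately mirrors the middleware's case-insensitive substring checks.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical outcome of the authentication check for one request.
public enum RouteDecision
{
    Continue,       // let the request reach the next middleware
    RedirectHome,   // authenticated user asked for the login page
    RedirectLogin   // anonymous user asked for a protected page
}

public static class AuthRouting
{
    // Same public paths as the middleware above.
    private static readonly List<string> PublicPaths = new List<string> {
        "security/forgotpassword",
        "security/resetpassword"
    };

    public static RouteDecision Decide(string path, bool authenticated)
    {
        // Mirror the middleware: trim, lower-case, then substring matching.
        string normalized = (path ?? string.Empty).Trim().ToLowerInvariant();
        bool isLogin = normalized.Contains("security/login");
        bool isPublic = PublicPaths.Any(p => normalized.Contains(p));

        if (isPublic)
            return RouteDecision.Continue;
        if (authenticated)
            return isLogin ? RouteDecision.RedirectHome : RouteDecision.Continue;
        return isLogin ? RouteDecision.Continue : RouteDecision.RedirectLogin;
    }
}
```

Substring matching is permissive: a path such as `/reports/security/login-history` would also be treated as the login page. `PathString.StartsWithSegments` in ASP.NET Core is a stricter option.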
    

    ISecurityManager interface:

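    using System.Threading.Tasks;
    using Microsoft.AspNetCore.Http;

    public interface ISecurityManager
    {
        // Issues the encrypted auth cookie; persistent = "remember me"
        void Login(HttpContext context, UserMetadata userMetadata, bool persistent);
        // Removes the auth cookie
        void Logout(HttpContext context);
        // Validates and refreshes the cookie; returns true if the request is authenticated
        Task<bool> ProcessSecurityContext(HttpContext context);
    }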
    

    Encryption Library:

    using System;
    using System.IO;
    using System.Security.Cryptography;
    using System.Text;

    internal class EncryptionLib
    {
        /// <summary>
        /// Encrypt the given string using AES. The string can be decrypted using
        /// DecryptStringAES(). The sharedSecret and salt parameters must match.
        /// A random IV is generated for every call and prepended to the ciphertext,
        /// so encrypting the same text twice yields different output.
        /// </summary>
        /// <param name="plainText">The text to encrypt.</param>
        /// <param name="sharedSecret">A password used to generate a key for encryption.</param>
        /// <param name="salt">The key salt used to derive the key (at least 8 bytes).</param>
        /// <exception cref="ArgumentNullException">Text is null or empty.</exception>
        /// <exception cref="ArgumentNullException">Password is null or empty.</exception>
        public static string EncryptStringAES(string plainText, string sharedSecret, byte[] salt)
        {
            if (string.IsNullOrEmpty(plainText))
                throw new ArgumentNullException(nameof(plainText), "Text is null or empty.");
            if (string.IsNullOrEmpty(sharedSecret))
                throw new ArgumentNullException(nameof(sharedSecret), "Password is null or empty.");

            using (var key = new Rfc2898DeriveBytes(sharedSecret, salt))
            using (var aesAlg = Aes.Create())
            {
                aesAlg.Key = key.GetBytes(aesAlg.KeySize / 8);
                aesAlg.GenerateIV(); // fresh IV per message instead of one derived from the password

                using (ICryptoTransform encryptor = aesAlg.CreateEncryptor(aesAlg.Key, aesAlg.IV))
                using (var msEncrypt = new MemoryStream())
                {
                    // Prefix the IV so DecryptStringAES can recover it
                    msEncrypt.Write(aesAlg.IV, 0, aesAlg.IV.Length);
                    using (var csEncrypt = new CryptoStream(msEncrypt, encryptor, CryptoStreamMode.Write))
                    using (var swEncrypt = new StreamWriter(csEncrypt))
                        swEncrypt.Write(plainText);
                    // ToArray() still works after the stream has been closed
                    return Convert.ToBase64String(msEncrypt.ToArray());
                }
            }
        }

        /// <summary>
        /// Decrypt the given string. Assumes the string was encrypted using
        /// EncryptStringAES(), using an identical sharedSecret and salt.
        /// Returns null if the input is not valid Base64 or cannot be decrypted.
        /// </summary>
        /// <param name="cipherText">The text to decrypt.</param>
        /// <param name="sharedSecret">A password used to generate a key for decryption.</param>
        /// <param name="salt">The key salt used to derive the key.</param>
        /// <exception cref="ArgumentNullException">Text is null or empty.</exception>
        /// <exception cref="ArgumentNullException">Password is null or empty.</exception>
        public static string DecryptStringAES(string cipherText, string sharedSecret, byte[] salt)
        {
            if (string.IsNullOrEmpty(cipherText))
                throw new ArgumentNullException(nameof(cipherText), "Text is null or empty.");
            if (string.IsNullOrEmpty(sharedSecret))
                throw new ArgumentNullException(nameof(sharedSecret), "Password is null or empty.");

            try
            {
                byte[] bytes = Convert.FromBase64String(cipherText);
                using (var key = new Rfc2898DeriveBytes(sharedSecret, salt))
                using (var aesAlg = Aes.Create())
                {
                    aesAlg.Key = key.GetBytes(aesAlg.KeySize / 8);

                    // Read back the IV that EncryptStringAES prepended
                    byte[] iv = new byte[aesAlg.BlockSize / 8];
                    if (bytes.Length <= iv.Length)
                        return null;
                    Array.Copy(bytes, iv, iv.Length);

                    using (ICryptoTransform decryptor = aesAlg.CreateDecryptor(aesAlg.Key, iv))
                    using (var msDecrypt = new MemoryStream(bytes, iv.Length, bytes.Length - iv.Length))
                    using (var csDecrypt = new CryptoStream(msDecrypt, decryptor, CryptoStreamMode.Read))
                    using (var srDecrypt = new StreamReader(csDecrypt))
                        return srDecrypt.ReadToEnd();
                }
            }
            catch (FormatException) { return null; }          // not valid Base64
            catch (CryptographicException) { return null; }   // wrong key or corrupted data (bad padding)
        }
    
        // SHA-1 is no longer collision-resistant; prefer HashSHA256 for new code.
        public static string HashSHA1(string plainText)
        {
            using (var sha = SHA1.Create())
                return Convert.ToBase64String(sha.ComputeHash(Encoding.UTF8.GetBytes(plainText)));
        }

        public static string HashSHA256(string plainText)
        {
            using (var sha = SHA256.Create())
                return Convert.ToBase64String(sha.ComputeHash(Encoding.UTF8.GetBytes(plainText)));
        }

        // Used by SecurityCookie.Sign(). UTF-8 (rather than ASCII) makes the signature cover
        // non-ASCII characters too; HMACSHA256 is the more conservative choice for new code.
        public static string HashHMACMD5(string plainText, byte[] key)
        {
            using (var hmac = new HMACMD5(key))
                return Convert.ToBase64String(hmac.ComputeHash(Encoding.UTF8.GetBytes(plainText)));
        }
    }
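The library above pairs AES-CBC encryption with a separate HMAC signature (applied in `SecurityCookie`). As an alternative, and not what the answer uses, .NET Core 3.0 and later provide `System.Security.Cryptography.AesGcm`, an authenticated-encryption mode that detects tampering on its own, so no separate signature step is needed. The following is a minimal sketch, assuming a 32-byte random key kept out of source control; the class name `AeadProtector` is hypothetical. On .NET 8 and later the single-argument `AesGcm(key)` constructor used here compiles with an obsolescence warning, and `AesGcm(key, tagSizeInBytes)` is the preferred form.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

public static class AeadProtector
{
    private const int NonceSize = 12; // 96-bit nonce, the standard size for GCM
    private const int TagSize = 16;   // 128-bit authentication tag

    // Returns Base64(nonce || tag || ciphertext).
    public static string Protect(string plainText, byte[] key)
    {
        byte[] plain = Encoding.UTF8.GetBytes(plainText);
        byte[] nonce = new byte[NonceSize];
        RandomNumberGenerator.Fill(nonce); // fresh nonce per message; never reuse one with the same key
        byte[] cipher = new byte[plain.Length];
        byte[] tag = new byte[TagSize];

        using (var aes = new AesGcm(key))
            aes.Encrypt(nonce, plain, cipher, tag);

        byte[] output = new byte[NonceSize + TagSize + cipher.Length];
        Buffer.BlockCopy(nonce, 0, output, 0, NonceSize);
        Buffer.BlockCopy(tag, 0, output, NonceSize, TagSize);
        Buffer.BlockCopy(cipher, 0, output, NonceSize + TagSize, cipher.Length);
        return Convert.ToBase64String(output);
    }

    // Returns null if the input is malformed or has been tampered with.
    public static string Unprotect(string protectedText, byte[] key)
    {
        byte[] input;
        try { input = Convert.FromBase64String(protectedText); }
        catch (FormatException) { return null; }
        if (input.Length < NonceSize + TagSize)
            return null;

        byte[] nonce = new byte[NonceSize];
        byte[] tag = new byte[TagSize];
        byte[] cipher = new byte[input.Length - NonceSize - TagSize];
        Buffer.BlockCopy(input, 0, nonce, 0, NonceSize);
        Buffer.BlockCopy(input, NonceSize, tag, 0, TagSize);
        Buffer.BlockCopy(input, NonceSize + TagSize, cipher, 0, cipher.Length);

        byte[] plain = new byte[cipher.Length];
        try
        {
            using (var aes = new AesGcm(key))
                aes.Decrypt(nonce, cipher, tag, plain);
        }
        catch (CryptographicException)
        {
            return null; // tag mismatch: the data was altered or the key is wrong
        }
        return Encoding.UTF8.GetString(plain);
    }
}
```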
    

    SecurityCookie:

    internal class SecurityCookie
    {
        public string Name { get; set; }
        public DateTime ExpiryDate { get; set; }
        public SecurityContext Context { get; set; }
    
        #region Security keys
    
        private static readonly byte[] _cookieSigEncryptionKey = new byte[64] {
            //Enter 64 bytes of encryption keys
        };
    
        private static readonly byte[] _cookiePayloadEncryptionSalt = new byte[48] {
            //Enter 48 bytes of encryption salts
        };
    
        private static readonly string _cookiePayloadEncryptionKey = "here goes your complex encryption password";
    
        #endregion
    
        private static readonly string _cookiePayloadToken = "~~";
    
        public SecurityCookie(string name, DateTime expiryDate, SecurityContext context)
        {
            this.Name = name;
            this.ExpiryDate = expiryDate;
            this.Context = context;
        }
    
        public SecurityCookie()
        {
            this.Name = Guid.NewGuid().ToString();
        }
    
        public KeyValuePair<string,Tuple<CookieOptions,string>> CreateCookie()
        {
            CookieOptions cookieOptions = new CookieOptions() {
                Expires = ExpiryDate,
                HttpOnly = true,   // not readable from JavaScript
                Secure = true      // sent only over HTTPS; set to false only for local HTTP development
            };

            return new KeyValuePair<string, Tuple<CookieOptions,string>>(Name, new Tuple<CookieOptions,string>(cookieOptions, GenerateContent()));
        }
    
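        public static SecurityContext ExtractContent(string encryptedContent)
        {
            string decryptedContent = null;

            try {
                // Undo the outer Base64 layer added by EncryptAndHashCookieContent, then decrypt
                string decodedContent = Encoding.ASCII.GetString(Convert.FromBase64String(encryptedContent));
                decryptedContent = EncryptionLib.DecryptStringAES(decodedContent, _cookiePayloadEncryptionKey, _cookiePayloadEncryptionSalt);
            } catch {
                // Any malformed or foreign cookie value is treated as "not authenticated"
                decryptedContent = null;
            }

            if (string.IsNullOrWhiteSpace(decryptedContent))
                return null;

            // Split at the LAST token: the Base64 signature never contains '~',
            // so JSON that happens to contain "~~" is still parsed correctly
            int tokenIndex = decryptedContent.LastIndexOf(_cookiePayloadToken, StringComparison.Ordinal);
            if (tokenIndex <= 0)
                return null;

            string payload = decryptedContent.Substring(0, tokenIndex);
            string signature = decryptedContent.Substring(tokenIndex + _cookiePayloadToken.Length);

            // Constant-time comparison, so response timing does not reveal how much of the signature matched
            if (!System.Security.Cryptography.CryptographicOperations.FixedTimeEquals(
                    Encoding.ASCII.GetBytes(signature), Encoding.ASCII.GetBytes(Sign(payload))))
                return null;

            return JsonConvert.DeserializeObject<SecurityContext>(payload);
        }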
    
        public string GenerateContent()
        {
            string data = JsonConvert.SerializeObject(Context);
            string signature = Sign(data);
    
            // _cookiePayloadToken denotes end of payload segment and start of signature (checksum) segment
            return EncryptAndHashCookieContent(data + _cookiePayloadToken + signature);
        }
    
        private string EncryptAndHashCookieContent(string content)
        {
            return Convert.ToBase64String(
                Encoding.ASCII.GetBytes(
                    EncryptionLib.EncryptStringAES(content,_cookiePayloadEncryptionKey,_cookiePayloadEncryptionSalt)
                )
            ); 
        }
    
        private static string Sign(string data)
        {
            return EncryptionLib.HashHMACMD5(data,_cookieSigEncryptionKey);
        }
    }
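`GenerateContent` and `ExtractContent` frame the plaintext as `json + "~~" + signature` before encrypting it. The sketch below isolates that framing step, with two assumptions that differ from the answer: it signs with HMAC-SHA256 instead of HMAC-MD5, and it locates the signature at the last `~~`, so a payload containing `~~` (for example in a user's name) is not rejected the way the original `Split` check would reject it. The class name `SignedPayload` is hypothetical, the encryption layer is omitted, and the key is assumed to be a secret random byte array like `_cookieSigEncryptionKey`.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

public static class SignedPayload
{
    private const string Separator = "~~"; // same token as _cookiePayloadToken

    private static string Sign(string data, byte[] key)
    {
        using (var hmac = new HMACSHA256(key))
            return Convert.ToBase64String(hmac.ComputeHash(Encoding.UTF8.GetBytes(data)));
    }

    public static string Wrap(string data, byte[] key) => data + Separator + Sign(data, key);

    // Returns the original data, or null if the separator or a valid signature is missing.
    public static string Unwrap(string framed, byte[] key)
    {
        // The Base64 signature never contains '~', so the last separator marks its start.
        int index = framed.LastIndexOf(Separator, StringComparison.Ordinal);
        if (index < 0)
            return null;

        string data = framed.Substring(0, index);
        string signature = framed.Substring(index + Separator.Length);
        byte[] expected = Encoding.ASCII.GetBytes(Sign(data, key));
        byte[] actual = Encoding.ASCII.GetBytes(signature);

        // FixedTimeEquals does not exit early on the first differing byte.
        return CryptographicOperations.FixedTimeEquals(expected, actual) ? data : null;
    }
}
```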
    

    Security manager:

     public class SecurityManager : ISecurityManager
     {
        private const string _authCookieName = "aa-bm"; 
        private const string _authCookieItemsKey = "security-cookie";
        private const int _authCookieExpiryMinutesPersistent = 60 * 24 * 30;    // 30 days ("remember me")
        private const int _authCookieExpiryMinutesTransient = 60 * 1;           // 1 hour
        private const int _authCookieSecurityContextRefreshMinutes = 5;         // reload user metadata at most every 5 minutes
        private const string _securityContextCurrentVersion = "1.0";
    
        private readonly ISecurityService _securityService;
    
        public SecurityManager(ISecurityService securityService)
        {
            this._securityService = securityService;
        }
    
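        private string GetSecurityCookieValue(HttpContext context)
        {
            // In ASP.NET Core 1.0 and later the request cookie indexer returns a string
            // (null when the cookie is absent); the Count/[0] access targeted the RC1 StringValues API.
            return context.Request.Cookies[_authCookieName];
        }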
    
        public async Task<bool> ProcessSecurityContext(HttpContext context)
        {
            string encryptedValue = GetSecurityCookieValue(context);
            if (string.IsNullOrWhiteSpace(encryptedValue))
                return false;
    
            SecurityContext securityContext = ExtractSecurityContext(encryptedValue);
            if (securityContext == null || securityContext.Metadata.UserID <= 0 || securityContext.ContextVersion != _securityContextCurrentVersion) {
                context.Response.Cookies.Delete(_authCookieName);
                return false;           
            }
    
            securityContext = await RefreshCookieContext(context,securityContext);
            if(securityContext == null) {
                context.Response.Cookies.Delete(_authCookieName);
                return false;
            }
    
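            // Serialize explicitly so the claim holds JSON, which GetUserSecurityContext deserializes.
            // A non-empty authentication type ("CustomCookie" is an arbitrary name) makes
            // User.Identity.IsAuthenticated return true.
            ClaimsIdentity identity = new ClaimsIdentity(
                new[] { new Claim(ClaimTypes.UserData, Newtonsoft.Json.JsonConvert.SerializeObject(securityContext)) },
                "CustomCookie");
            context.User = new ClaimsPrincipal(identity);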
    
            return true;
        }
    
        private SecurityContext ExtractSecurityContext(string encryptedValue)
        {
            return SecurityCookie.ExtractContent(encryptedValue);
        }
    
        private async Task<SecurityContext> RefreshCookieContext(HttpContext context,SecurityContext currentContext)
        {
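            // Note: this re-issues even a persistent ("remember me") cookie with the transient lifetime.
            // To keep it persistent, store the login's persistence choice in SecurityContext.
            DateTime expiry = DateTime.Now.AddMinutes(_authCookieExpiryMinutesTransient);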
    
            currentContext.Expires = expiry;
            if(currentContext.SecurityDomainContext.RefreshSecurityContextDate < DateTime.Now) {
                UserMetadata userMetadata = await _securityService.GetUserMetadata(currentContext.Metadata.UserID);
                if(userMetadata == null)
                    return null;
                currentContext.Metadata = userMetadata;
                currentContext.SecurityDomainContext.RefreshSecurityContextDate = DateTime.Now.AddMinutes(_authCookieSecurityContextRefreshMinutes);
                currentContext.Status = userMetadata.Status;
            }
    
            SecurityCookie secureCookie = new SecurityCookie(
                _authCookieName,
                expiry,
                currentContext
            );
    
            var cookie = secureCookie.CreateCookie();
    
            context.Response.Cookies.Append(cookie.Key,cookie.Value.Item2,cookie.Value.Item1);
    
            return currentContext;
        }
    
        public void Login(HttpContext context, UserMetadata userMetadata, bool persistent) 
        {
            DateTime expiry = (persistent ? DateTime.Now.AddMinutes(_authCookieExpiryMinutesPersistent) : DateTime.Now.AddMinutes(_authCookieExpiryMinutesTransient));
            SecurityCookie secureCookie = new SecurityCookie(
                _authCookieName,
                expiry,
                new SecurityContext() {
                    Expires = expiry,
                    Metadata = userMetadata,
                    SecurityDomainContext = new SecurityDomainContext() {
                        RefreshSecurityContextDate = DateTime.Now.AddMinutes(_authCookieSecurityContextRefreshMinutes)
                    },
                    ContextVersion = _securityContextCurrentVersion
                }
            );
    
            var cookie = secureCookie.CreateCookie();
            context.Response.Cookies.Append(cookie.Key,cookie.Value.Item2,cookie.Value.Item1);
        }
    
        public void Logout(HttpContext context)
        {
            context.Response.Cookies.Delete(_authCookieName, new CookieOptions() { Expires = DateTime.Now.AddDays(-365) });
        }
    }
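Because `RefreshCookieContext` always uses the transient lifetime, a "remember me" login effectively expires one hour after the user's latest request. The sketch below assumes a hypothetical `IsPersistent` flag stored in the cookie payload and shows a refresh rule that keeps the lifetime chosen at login, slides it forward, and reports when user metadata should be reloaded (the same `RefreshSecurityContextDate < now` test as the original). `SessionState` and `SessionRefresh` are simplified, hypothetical stand-ins, not the answer's classes.

```csharp
using System;

// Simplified, hypothetical stand-in for the cookie payload.
public class SessionState
{
    public bool IsPersistent { get; set; }               // the "remember me" choice made at login
    public DateTime Expires { get; set; }
    public DateTime RefreshSecurityContextDate { get; set; }
}

public static class SessionRefresh
{
    public static readonly TimeSpan PersistentLifetime = TimeSpan.FromDays(30); // 60 * 24 * 30 minutes
    public static readonly TimeSpan TransientLifetime = TimeSpan.FromHours(1);  // 60 * 1 minutes
    public static readonly TimeSpan MetadataRefresh = TimeSpan.FromMinutes(5);

    // Slides the expiry forward and returns true if user metadata should be reloaded.
    public static bool Refresh(SessionState state, DateTime now)
    {
        state.Expires = now + (state.IsPersistent ? PersistentLifetime : TransientLifetime);

        if (state.RefreshSecurityContextDate >= now)
            return false; // metadata is still fresh enough

        state.RefreshSecurityContextDate = now + MetadataRefresh;
        return true;      // caller should reload metadata, e.g. via ISecurityService
    }
}
```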
    

    Security Extensions:

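    using System.Linq;
    using System.Security.Claims;
    using Microsoft.AspNetCore.Builder;
    using Newtonsoft.Json;

    public static class SecurityExtensions
    {
        public static IApplicationBuilder UseCustomAuthentication(this IApplicationBuilder builder)
        {
            // Registers the CustomAuthentication middleware shown above
            return builder.UseMiddleware<CustomAuthentication>();
        }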
    
        public static SecurityContext GetUserSecurityContext(this ClaimsPrincipal claimsPrincipal)
        {
            Claim userDataClaim = claimsPrincipal.FindFirst(ClaimTypes.UserData);

            if (userDataClaim == null)
                return null;   // anonymous request: no security context was attached

            return JsonConvert.DeserializeObject<SecurityContext>(userDataClaim.Value);
        }

        // True only if the user has ALL of the given roles
        public static bool HasRoles(this ClaimsPrincipal claimsPrincipal, params RoleEnum[] roles)
        {
            SecurityContext context = claimsPrincipal.GetUserSecurityContext();
            if (context == null)
                return false;

            return roles.All(role => context.Metadata.Roles.Contains(role));
        }

        // True only if the user has ALL of the given permissions
        public static bool HasPermissions(this ClaimsPrincipal claimsPrincipal, params PermissionEnum[] permissions)
        {
            SecurityContext context = claimsPrincipal.GetUserSecurityContext();
            if (context == null)
                return false;

            return permissions.All(perm => context.Metadata.Permissions.Contains(perm));
        }
    
        public static int GetUserID(this ClaimsPrincipal claimsPrincipal)
        {
            SecurityContext context = claimsPrincipal.GetUserSecurityContext();
            return context.Metadata.UserID;
        }
    
        public static int GetCompanyID(this ClaimsPrincipal claimsPrincipal)
        {
            SecurityContext context = claimsPrincipal.GetUserSecurityContext();
            return context.Metadata.CompanyID;
        }
    
        public static PersonalSettings GetUserSettings(this ClaimsPrincipal claimsPrincipal)
        {
            SecurityContext context = claimsPrincipal.GetUserSecurityContext();
            return context.Metadata.Settings;
        }
    
        public static string GetFullName(this ClaimsPrincipal claimsPrincipal)
        {
            SecurityContext context = claimsPrincipal.GetUserSecurityContext();
            return context.Metadata.FirstName + " " + context.Metadata.LastName;
        }
    
        public static string GetFirstName(this ClaimsPrincipal claimsPrincipal)
        {
            SecurityContext context = claimsPrincipal.GetUserSecurityContext();
            return context.Metadata.FirstName;
        }
    
        public static string GetLastName(this ClaimsPrincipal claimsPrincipal)
        {
            SecurityContext context = claimsPrincipal.GetUserSecurityContext();
            return context.Metadata.LastName;
        }
    }
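`ProcessSecurityContext` stores the security context in a single `ClaimTypes.UserData` claim, and `GetUserSecurityContext` reads it back. The framework-free sketch below shows that round trip with `System.Security.Claims`; it uses `System.Text.Json` in place of the answer's Newtonsoft.Json only so that no packages are needed. It also shows why the identity needs an authentication type: without one, `IsAuthenticated` is false. The `UserInfo` class and the `"CustomCookie"` type name are hypothetical.

```csharp
using System.Security.Claims;
using System.Text.Json;

// Hypothetical, trimmed-down payload; the answer's SecurityContext has more fields.
public class UserInfo
{
    public int UserID { get; set; }
    public string FirstName { get; set; }
}

public static class ClaimsRoundTrip
{
    // Build a principal the way the middleware would after validating the cookie.
    public static ClaimsPrincipal CreatePrincipal(UserInfo info)
    {
        var claims = new[] { new Claim(ClaimTypes.UserData, JsonSerializer.Serialize(info)) };
        // A non-empty authentication type makes Identity.IsAuthenticated return true.
        var identity = new ClaimsIdentity(claims, "CustomCookie");
        return new ClaimsPrincipal(identity);
    }

    // Same idea as GetUserSecurityContext: find the claim and deserialize it.
    public static UserInfo ReadUserInfo(ClaimsPrincipal principal)
    {
        Claim claim = principal.FindFirst(ClaimTypes.UserData);
        return claim == null ? null : JsonSerializer.Deserialize<UserInfo>(claim.Value);
    }
}
```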
    

    Example of permission-based authorization: decorate individual controller actions, or an entire controller, with this attribute and list the required permissions. A nearly identical RolesAuthorize attribute does the same for roles:

    using System.Threading.Tasks;
    using Microsoft.AspNetCore.Mvc.Filters;

    public class AuthorizePermissionsAttribute : ActionFilterAttribute
    {
        public PermissionEnum[] Permissions { get; set; }

        public AuthorizePermissionsAttribute(params PermissionEnum[] permissions)
        {
            this.Permissions = permissions;
        }

        public override Task OnActionExecutionAsync(ActionExecutingContext context, ActionExecutionDelegate next)
        {
            // The user must hold every listed permission
            bool hasPermissions = context.HttpContext.User.HasPermissions(Permissions);

            if (!hasPermissions) {
                // UnauthenticatedException is the author's own exception type (not shown)
                throw new UnauthenticatedException();
            }

            return base.OnActionExecutionAsync(context, next);
        }
    }
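`HasPermissions` (and the analogous `HasRoles`) succeeds only when the user holds every listed permission, so an attribute listing two permissions requires both, not either. A framework-free sketch of that rule follows; `PermissionEnum` here is a hypothetical stand-in for the answer's enum, and a missing context (anonymous user) is treated as having no permissions.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for the answer's PermissionEnum.
public enum PermissionEnum { ViewOrders, EditOrders, DeleteOrders }

public static class PermissionCheck
{
    // True only if every required permission is in the granted set (AND semantics).
    public static bool HasAll(IEnumerable<PermissionEnum> granted, params PermissionEnum[] required)
    {
        if (granted == null)
            return false; // anonymous user: no security context, so no permissions
        var grantedSet = new HashSet<PermissionEnum>(granted);
        return required.All(grantedSet.Contains);
    }
}
```

If OR semantics were wanted instead, `required.Any(...)` would be the change.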
    

    Now for usage.

    You use the middleware for authentication and the custom attributes for authorization. The SecurityManager class is the central point for manipulating security: it serializes and encrypts the security context into the cookie, decrypts and deserializes it on every request, and the controllers read that data through the SecurityExtensions methods. Shape the SecurityContext class however you need; I suggest storing at least three or four of these fields in it (the whole context is encrypted in the cookie): UserID, CompanyID, UserSettings, Roles, Permissions.
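The answer does not show `SecurityContext`, `SecurityDomainContext`, or `UserMetadata`. The sketch below infers one plausible shape purely from how the code above uses them (`Expires`, `Metadata`, `SecurityDomainContext.RefreshSecurityContextDate`, `ContextVersion`, `Status`, and the metadata fields read by the extension methods). Every type choice is an assumption: the enum values, `PersonalSettings`, and the `string` type of `Status` are placeholders, the `ToString()` override returning JSON is only inferred from `ProcessSecurityContext` passing `securityContext.ToString()` into a claim that is later parsed as JSON, and `System.Text.Json` stands in for the answer's Newtonsoft.Json.

```csharp
using System;
using System.Collections.Generic;
using System.Text.Json;

// All shapes below are inferred from usage in the answer; adjust them to your needs.
public enum RoleEnum { Administrator, Employee }        // placeholder values
public enum PermissionEnum { ViewOrders, EditOrders }   // placeholder values

public class PersonalSettings                            // placeholder
{
    public string Language { get; set; }
}

public class UserMetadata
{
    public int UserID { get; set; }
    public int CompanyID { get; set; }
    public string FirstName { get; set; }
    public string LastName { get; set; }
    public string Status { get; set; }                   // type assumed
    public PersonalSettings Settings { get; set; }
    public List<RoleEnum> Roles { get; set; } = new List<RoleEnum>();
    public List<PermissionEnum> Permissions { get; set; } = new List<PermissionEnum>();
}

public class SecurityDomainContext
{
    public DateTime RefreshSecurityContextDate { get; set; }
}

public class SecurityContext
{
    public string ContextVersion { get; set; }
    public DateTime Expires { get; set; }
    public string Status { get; set; }                   // copied from UserMetadata.Status
    public UserMetadata Metadata { get; set; }
    public SecurityDomainContext SecurityDomainContext { get; set; }

    // Returns JSON, so ToString() yields what GetUserSecurityContext expects to parse.
    public override string ToString() => JsonSerializer.Serialize(this);
}
```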

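    In Startup.cs, in the Configure method, call app.UseCustomAuthentication(); immediately after app.UseStaticFiles() or app.UseIISPlatformHandler(). The ordering is very important: middleware runs in registration order, so static files (for example the login page's CSS and scripts) are served before the authentication check, while every request that reaches MVC has already passed through it.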
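Why the ordering matters can be shown with a toy model of a middleware pipeline: each step either answers the request or calls the next step, and steps run in registration order. This is plain C#, not ASP.NET Core's API; `ToyPipeline`, `Step`, and the `/css/` path convention are hypothetical. With static files registered first, `/css/site.css` is served to an anonymous user, while with the authentication step first the same request is redirected to the login page.

```csharp
using System;
using System.Collections.Generic;

// A middleware step: inspect the request, then either return a response or call next().
public delegate string Step(string path, bool authenticated, Func<string> next);

public static class ToyPipeline
{
    public static string Run(IList<Step> steps, string path, bool authenticated)
    {
        // Compose from the last step backwards, so steps[0] runs first.
        Func<string> app = () => "404 Not Found";
        for (int i = steps.Count - 1; i >= 0; i--)
        {
            Step step = steps[i];
            Func<string> next = app;
            app = () => step(path, authenticated, next);
        }
        return app();
    }

    public static readonly Step StaticFiles = (path, auth, next) =>
        path.StartsWith("/css/") ? "200 static file" : next();

    public static readonly Step Authentication = (path, auth, next) =>
        (auth || path == "/security/login") ? next() : "302 -> /security/login";

    public static readonly Step Mvc = (path, auth, next) => "200 page " + path;
}
```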

    Unfortunately I cannot write you a detailed description of usage right now since I am very busy. But I think the code is self-explanatory. This works in a production system right now so it is field tested.

    If you have specific questions after a thorough analysis, please ask away!