// Copyright (c) Microsoft Corporation. All rights reserved. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Configuration.Provider;
using System.Diagnostics;
using System.Globalization;
using System.Linq;
using System.Security.Cryptography;
using System.Text;
using System.Web;
using System.Web.Helpers;
using System.Web.Security;
using System.Web.WebPages;
using Microsoft.Internal.Web.Utils;
using WebMatrix.Data;
using WebMatrix.WebData.Resources;
namespace WebMatrix.WebData
{
public class SimpleMembershipProvider : ExtendedMembershipProvider
{
// Byte length for generated tokens — presumably consumed by GenerateToken (not visible in this chunk); verify there.
private const int TokenSizeInBytes = 16;

// Provider configured before this one. While Simple Membership is not initialized, most members forward to it.
private readonly MembershipProvider _previousProvider;

/// <summary>
/// Creates a provider that has no previous provider to fall back on.
/// </summary>
public SimpleMembershipProvider()
    : this(null)
{
}

/// <summary>
/// Creates a provider that falls back on <paramref name="previousProvider"/> until Simple Membership is initialized.
/// </summary>
/// <param name="previousProvider">The fallback provider; may be null.</param>
public SimpleMembershipProvider(MembershipProvider previousProvider)
{
    _previousProvider = previousProvider;
    if (_previousProvider == null)
    {
        return;
    }

    // Re-raise the previous provider's password validation events through this provider,
    // but only as long as this provider has not been initialized.
    _previousProvider.ValidatingPassword += (source, validationArgs) =>
    {
        if (InitializeCalled)
        {
            return;
        }
        OnValidatingPassword(validationArgs);
    };
}
/// <summary>
/// The provider supplied to the constructor, used as the fallback while Simple Membership is not initialized.
/// </summary>
/// <exception cref="InvalidOperationException">No previous provider was supplied.</exception>
private MembershipProvider PreviousProvider
{
    get
    {
        if (_previousProvider != null)
        {
            return _previousProvider;
        }

        // With no fallback available, the only way forward is to initialize Simple Membership.
        throw new InvalidOperationException(WebDataResources.Security_InitializeMustBeCalledFirst);
    }
}
// Public properties
// Inherited from MembershipProvider ==> Forwarded to previous provider if this provider hasn't been initialized.
// Once initialized, this provider always reports false.
public override bool EnablePasswordRetrieval
{
    get { return !InitializeCalled && PreviousProvider.EnablePasswordRetrieval; }
}
// Inherited from MembershipProvider ==> Forwarded to previous provider if this provider hasn't been initialized.
// Once initialized, this provider always reports false.
public override bool EnablePasswordReset
{
    get { return !InitializeCalled && PreviousProvider.EnablePasswordReset; }
}
// Inherited from MembershipProvider ==> Forwarded to previous provider if this provider hasn't been initialized.
// Once initialized, this provider always reports false.
public override bool RequiresQuestionAndAnswer
{
    get { return !InitializeCalled && PreviousProvider.RequiresQuestionAndAnswer; }
}
// Inherited from MembershipProvider ==> Forwarded to previous provider if this provider hasn't been initialized.
// Once initialized, this provider always reports false.
public override bool RequiresUniqueEmail
{
    get { return !InitializeCalled && PreviousProvider.RequiresUniqueEmail; }
}
// Inherited from MembershipProvider ==> Forwarded to previous provider if this provider hasn't been initialized.
public override MembershipPasswordFormat PasswordFormat
{
    get
    {
        if (!InitializeCalled)
        {
            return PreviousProvider.PasswordFormat;
        }

        // Initialized: passwords are stored as hashes (CreateAccount stores Crypto.HashPassword output).
        return MembershipPasswordFormat.Hashed;
    }
}
// Inherited from MembershipProvider ==> Forwarded to previous provider if this provider hasn't been initialized.
public override int MaxInvalidPasswordAttempts
{
    get
    {
        if (!InitializeCalled)
        {
            return PreviousProvider.MaxInvalidPasswordAttempts;
        }

        // Initialized: report the largest possible attempt count.
        return Int32.MaxValue;
    }
}
// Inherited from MembershipProvider ==> Forwarded to previous provider if this provider hasn't been initialized.
public override int PasswordAttemptWindow
{
    get
    {
        if (!InitializeCalled)
        {
            return PreviousProvider.PasswordAttemptWindow;
        }

        // Initialized: report the largest possible window.
        return Int32.MaxValue;
    }
}
// Inherited from MembershipProvider ==> Forwarded to previous provider if this provider hasn't been initialized.
public override int MinRequiredPasswordLength
{
    get
    {
        if (!InitializeCalled)
        {
            return PreviousProvider.MinRequiredPasswordLength;
        }

        // Initialized: no minimum length is reported.
        return 0;
    }
}
// Inherited from MembershipProvider ==> Forwarded to previous provider if this provider hasn't been initialized.
public override int MinRequiredNonAlphanumericCharacters
{
    get
    {
        if (!InitializeCalled)
        {
            return PreviousProvider.MinRequiredNonAlphanumericCharacters;
        }

        // Initialized: no non-alphanumeric characters are required.
        return 0;
    }
}
// Inherited from MembershipProvider ==> Forwarded to previous provider if this provider hasn't been initialized.
public override string PasswordStrengthRegularExpression
{
    get
    {
        if (!InitializeCalled)
        {
            return PreviousProvider.PasswordStrengthRegularExpression;
        }

        // Initialized: no strength expression is configured.
        return String.Empty;
    }
}
// Inherited from MembershipProvider ==> Forwarded to previous provider if this provider hasn't been initialized.
// Once initialized, both reading and writing throw NotSupportedException.
public override string ApplicationName
{
    get
    {
        if (!InitializeCalled)
        {
            return PreviousProvider.ApplicationName;
        }
        throw new NotSupportedException();
    }
    set
    {
        if (InitializeCalled)
        {
            throw new NotSupportedException();
        }
        PreviousProvider.ApplicationName = value;
    }
}
/// <summary>
/// Name of the table holding password and account-confirmation data; its primary key is UserId.
/// </summary>
internal static string MembershipTableName
{
    get
    {
        return "webpages_Membership";
    }
}
/// <summary>
/// Name of the table mapping (Provider, ProviderUserId) pairs to local user ids.
/// </summary>
internal static string OAuthMembershipTableName
{
    get
    {
        return "webpages_OAuthMembership";
    }
}
// The user table and column names come from publicly settable properties rather than constants.
// They are concatenated directly into SQL text, so each is emitted as a bracket-delimited identifier.
// Any ']' inside a name is doubled so it cannot close the delimiter early.
private string SafeUserTableName
{
    get { return BracketQuoteIdentifier(UserTableName); }
}

private string SafeUserIdColumn
{
    get { return BracketQuoteIdentifier(UserIdColumn); }
}

private string SafeUserNameColumn
{
    get { return BracketQuoteIdentifier(UserNameColumn); }
}

/// <summary>
/// Wraps <paramref name="identifier"/> in square brackets, escaping each embedded ']' as "]]".
/// </summary>
/// <param name="identifier">The raw table or column name; null is treated as empty.</param>
/// <returns>The delimited identifier; "[]" for a null or empty name (same as the previous concatenation).</returns>
private static string BracketQuoteIdentifier(string identifier)
{
    if (String.IsNullOrEmpty(identifier))
    {
        return "[]";
    }
    return "[" + identifier.Replace("]", "]]") + "]";
}
/// <summary>
/// Name of the application's user table, which holds the user id and user name columns.
/// </summary>
public string UserTableName { get; set; }

/// <summary>
/// Name of the application-defined user name column in the user table (for example, Email).
/// </summary>
public string UserNameColumn { get; set; }

/// <summary>
/// Name of the application-defined user id column in the user table (for example, ID).
/// </summary>
/// <remarks>REVIEW: this could be derived from the user table's primary key in the future.</remarks>
public string UserIdColumn { get; set; }

/// <summary>
/// Connection information that ConnectToDatabase uses to open the membership database.
/// </summary>
internal DatabaseConnectionInfo ConnectionInfo { get; set; }

/// <summary>
/// True once Simple Membership has been initialized on this provider.
/// While false, inherited members forward to the previous provider.
/// </summary>
internal bool InitializeCalled { get; set; }
/// <summary>
/// Guards members that only work with Simple Membership enabled.
/// </summary>
/// <exception cref="InvalidOperationException">The provider has not been initialized.</exception>
internal override void VerifyInitialized()
{
    if (InitializeCalled)
    {
        return;
    }
    throw new InvalidOperationException(WebDataResources.Security_InitializeMustBeCalledFirst);
}
// Inherited from ProviderBase. The "previous provider" passed to the constructor has already been initialized by the
// configuration system, so this call is deliberately NOT forwarded to it.
/// <summary>
/// Initializes the provider from configuration.
/// Throws for any configuration attribute this provider does not recognize.
/// </summary>
/// <param name="name">Provider name; "SimpleMembershipProvider" is used when null or empty.</param>
/// <param name="config">Configuration attributes. Known attributes are removed from this collection.</param>
/// <exception cref="ArgumentNullException"><paramref name="config"/> is null.</exception>
/// <exception cref="ProviderException">An unrecognized attribute remains after known ones are removed.</exception>
public override void Initialize(string name, NameValueCollection config)
{
    if (config == null)
    {
        throw new ArgumentNullException("config");
    }

    string providerName = String.IsNullOrEmpty(name) ? "SimpleMembershipProvider" : name;

    // Supply a default description when none (or an empty one) was configured.
    if (String.IsNullOrEmpty(config["description"]))
    {
        config.Remove("description");
        config.Add("description", "Simple Membership Provider");
    }

    base.Initialize(providerName, config);

    // Attributes this provider accepts; strip them so that anything left over is unrecognized.
    string[] acceptedAttributes =
    {
        "connectionStringName",
        "enablePasswordRetrieval",
        "enablePasswordReset",
        "requiresQuestionAndAnswer",
        "applicationName",
        "requiresUniqueEmail",
        "maxInvalidPasswordAttempts",
        "passwordAttemptWindow",
        "passwordFormat",
        "name",
        "description",
        "minRequiredPasswordLength",
        "minRequiredNonalphanumericCharacters",
        "passwordStrengthRegularExpression",
        "hashAlgorithmType"
    };
    foreach (string attribute in acceptedAttributes)
    {
        config.Remove(attribute);
    }

    if (config.Count == 0)
    {
        return;
    }

    string leftoverAttribute = config.GetKey(0);
    if (!String.IsNullOrEmpty(leftoverAttribute))
    {
        throw new ProviderException(String.Format(CultureInfo.CurrentCulture, WebDataResources.SimpleMembership_ProviderUnrecognizedAttribute, leftoverAttribute));
    }
}
/// <summary>
/// Returns true when INFORMATION_SCHEMA.TABLES has a row whose TABLE_NAME equals <paramref name="tableName"/>.
/// </summary>
/// <param name="db">Database to query.</param>
/// <param name="tableName">Unquoted table name; it is passed as a query parameter, not concatenated into SQL.</param>
internal static bool CheckTableExists(IDatabase db, string tableName)
{
    return db.QuerySingle(@"SELECT * from INFORMATION_SCHEMA.TABLES where TABLE_NAME = @0", tableName) != null;
}
// Creates the user table, the OAuth membership table and the membership table when they are missing.
// Each table is checked independently; a table that already exists is left untouched (no schema migration).
// NOTE(review): the existence check uses the unquoted name, while CREATE TABLE uses the bracketed/constant name.
// NOTE(review): no transaction wraps the checks and creates.
internal void CreateTablesIfNeeded()
{
using (var db = ConnectToDatabase())
{
// User table: only the identity id column and a unique user-name column are created here.
if (!CheckTableExists(db, UserTableName))
{
db.Execute(@"CREATE TABLE " + SafeUserTableName + "(" + SafeUserIdColumn + " int NOT NULL PRIMARY KEY IDENTITY, " + SafeUserNameColumn + " nvarchar(56) NOT NULL UNIQUE)");
}
// OAuth table: one row per (Provider, ProviderUserId) pair, pointing at a local UserId.
if (!CheckTableExists(db, OAuthMembershipTableName))
{
db.Execute(@"CREATE TABLE " + OAuthMembershipTableName + "(Provider nvarchar(30) NOT NULL, ProviderUserId nvarchar(100) NOT NULL, UserId int NOT NULL, Primary Key (Provider, ProviderUserId))");
}
// Membership table: one row per UserId with password hash, confirmation and reset-token data.
// The verbatim SQL string below must remain byte-for-byte as written.
if (!CheckTableExists(db, MembershipTableName))
{
db.Execute(@"CREATE TABLE " + MembershipTableName + @" (
UserId int NOT NULL PRIMARY KEY,
CreateDate datetime ,
ConfirmationToken nvarchar(128) ,
IsConfirmed bit DEFAULT 0,
LastPasswordFailureDate datetime ,
PasswordFailuresSinceLastSuccess int NOT NULL DEFAULT 0,
Password nvarchar(128) NOT NULL,
PasswordChangedDate datetime ,
PasswordSalt nvarchar(128) NOT NULL,
PasswordVerificationToken nvarchar(128) ,
PasswordVerificationTokenExpirationDate datetime)");
// TODO: Do we want to add FK constraint to user table too?
// CONSTRAINT fk_UserId FOREIGN KEY (UserId) REFERENCES "+UserTableName+"("+UserIdColumn+"))");
}
}
}
/// <summary>
/// Looks up the id of <paramref name="userName"/> in the configured user table.
/// </summary>
/// <param name="userName">User name to find; matched case-insensitively.</param>
/// <returns>The user id, or -1 when no such user exists.</returns>
/// <remarks>Not an override ==> Simple Membership MUST be enabled to use this method.</remarks>
public int GetUserId(string userName)
{
    VerifyInitialized();

    using (IDatabase database = ConnectToDatabase())
    {
        return GetUserId(database, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, userName);
    }
}
/// <summary>
/// Looks up a user id by user name, comparing upper-cased values on both sides so that letter case is ignored.
/// </summary>
/// <param name="db">Database to query.</param>
/// <param name="userTableName">Already-delimited user table name (concatenated into the SQL text).</param>
/// <param name="userNameColumn">Already-delimited user name column (concatenated into the SQL text).</param>
/// <param name="userIdColumn">Already-delimited user id column (concatenated into the SQL text).</param>
/// <param name="userName">User name to find; must not be null.</param>
/// <returns>The user id, or -1 when no row matches.</returns>
internal static int GetUserId(IDatabase db, string userTableName, string userNameColumn, string userIdColumn, string userName)
{
    string sql = @"SELECT " + userIdColumn + " FROM " + userTableName + " WHERE (UPPER(" + userNameColumn + ") = @0)";
    var foundId = db.QueryValue(sql, userName.ToUpperInvariant());
    return foundId != null ? (int)foundId : -1;
}
/// <summary>
/// Returns the id of the user whose stored password verification token exactly matches <paramref name="token"/>.
/// </summary>
/// <param name="token">The password reset token to look up.</param>
/// <returns>The matching user id, or -1 when no membership row holds exactly this token.</returns>
/// <remarks>
/// Inherited from ExtendedMembershipProvider ==> Simple Membership MUST be enabled to use this method.
/// This lookup does not examine PasswordVerificationTokenExpirationDate.
/// </remarks>
public override int GetUserIdFromPasswordResetToken(string token)
{
    VerifyInitialized();
    using (var db = ConnectToDatabase())
    {
        // The SQL equality may be case-insensitive under the database collation.
        // That would let a token differing only in letter case resolve to a user.
        // As both ConfirmAccount overloads do, re-check the exact (ordinal) value on the client.
        var exactMatches = db.Query(@"SELECT UserId, PasswordVerificationToken FROM " + MembershipTableName + " WHERE (PasswordVerificationToken = @0)", token)
            .Where(r => String.Equals((string)r[1], token, StringComparison.Ordinal))
            .ToList();

        var row = exactMatches.FirstOrDefault();
        if (row == null || row[0] == null)
        {
            return -1;
        }
        return (int)row[0];
    }
}
// Inherited from MembershipProvider ==> Forwarded to previous provider if this provider hasn't been initialized.
// Once initialized, this operation is not supported and throws NotSupportedException.
public override bool ChangePasswordQuestionAndAnswer(string username, string password, string newPasswordQuestion, string newPasswordAnswer)
{
    if (InitializeCalled)
    {
        throw new NotSupportedException();
    }
    return PreviousProvider.ChangePasswordQuestionAndAnswer(username, password, newPasswordQuestion, newPasswordAnswer);
}
/// <summary>
/// Sets the confirmed flag for <paramref name="userName"/> if <paramref name="accountConfirmationToken"/> is correct.
/// </summary>
/// <param name="userName">User whose account should be confirmed.</param>
/// <param name="accountConfirmationToken">Token that must exactly (case-sensitively) match the stored token.</param>
/// <returns>
/// True if the account could be successfully confirmed.
/// False if the user name was not found or the confirmation token is invalid.
/// </returns>
/// <remarks>Inherited from ExtendedMembershipProvider ==> Simple Membership MUST be enabled to use this method.</remarks>
public override bool ConfirmAccount(string userName, string accountConfirmationToken)
{
    VerifyInitialized();
    using (var db = ConnectToDatabase())
    {
        // The SQL equality may ignore case depending on the database.
        // Comparing case-sensitively in SQL is hard to do portably, so the exact token is re-checked below.
        string sql = "SELECT m.[UserId], m.[ConfirmationToken] FROM " + MembershipTableName + " m JOIN " + SafeUserTableName + " u"
            + " ON m.[UserId] = u." + SafeUserIdColumn
            + " WHERE m.[ConfirmationToken] = @0 AND"
            + " u." + SafeUserNameColumn + " = @1";
        var match = db.QuerySingle(sql, accountConfirmationToken, userName);
        if (match == null)
        {
            return false;
        }

        int matchedUserId = match[0];
        string storedToken = match[1];
        if (!String.Equals(accountConfirmationToken, storedToken, StringComparison.Ordinal))
        {
            return false;
        }

        int rowsUpdated = db.Execute("UPDATE " + MembershipTableName + " SET [IsConfirmed] = 1 WHERE [UserId] = @0", matchedUserId);
        return rowsUpdated > 0;
    }
}
/// <summary>
/// Sets the confirmed flag for the account whose stored confirmation token exactly matches
/// <paramref name="accountConfirmationToken"/>.
/// </summary>
/// <param name="accountConfirmationToken">Token that must exactly (case-sensitively) match a stored token.</param>
/// <returns>
/// True if the account could be successfully confirmed.
/// False if no account holds the confirmation token.
/// </returns>
/// <remarks>
/// Inherited from ExtendedMembershipProvider ==> Simple Membership MUST be enabled to use this method.
/// Tokens that differ only in letter case are told apart by a client-side ordinal comparison.
/// If two accounts ever held byte-identical tokens, only the first matching row would be confirmed.
/// The overload that also takes the user name avoids this ambiguity.
/// </remarks>
public override bool ConfirmAccount(string accountConfirmationToken)
{
    VerifyInitialized();
    using (var db = ConnectToDatabase())
    {
        // The SQL equality may ignore case depending on the database, so keep only exact (ordinal) matches here.
        string sql = "SELECT [UserId], [ConfirmationToken] FROM " + MembershipTableName + " WHERE [ConfirmationToken] = @0";
        var exactMatches = db.Query(sql, accountConfirmationToken)
            .Where(candidate => ((string)candidate[1]).Equals(accountConfirmationToken, StringComparison.Ordinal))
            .ToList();

        Debug.Assert(exactMatches.Count < 2, "By virtue of the fact that the ConfirmationToken is random and unique, we can never have two tokens that are identical.");

        if (exactMatches.Count == 0)
        {
            return false;
        }

        int matchedUserId = exactMatches[0][0];
        int rowsUpdated = db.Execute("UPDATE " + MembershipTableName + " SET [IsConfirmed] = 1 WHERE [UserId] = @0", matchedUserId);
        return rowsUpdated > 0;
    }
}
/// <summary>
/// Obtains a connection from <see cref="ConnectionInfo"/> and wraps it as an <see cref="IDatabase"/>.
/// Virtual so the data access can be substituted.
/// </summary>
internal virtual IDatabase ConnectToDatabase()
{
    var connection = ConnectionInfo.Connect();
    return new DatabaseWrapper(connection);
}
/// <summary>
/// Creates the membership (password) row for a user who already exists in the user table.
/// </summary>
/// <param name="userName">Existing user name.</param>
/// <param name="password">Plain-text password; stored as the output of Crypto.HashPassword.</param>
/// <param name="requireConfirmationToken">
/// When true, the account is stored unconfirmed with a newly generated confirmation token.
/// </param>
/// <returns>The generated confirmation token, or null when <paramref name="requireConfirmationToken"/> is false.</returns>
/// <exception cref="MembershipCreateUserException">
/// InvalidPassword: password empty or hash too long.
/// InvalidUserName: user name empty.
/// ProviderError: user not found or insert failed.
/// DuplicateUserName: a membership row already exists.
/// </exception>
/// <remarks>Inherited from ExtendedMembershipProvider ==> Simple Membership MUST be enabled to use this method.</remarks>
public override string CreateAccount(string userName, string password, bool requireConfirmationToken)
{
    VerifyInitialized();

    if (password.IsEmpty())
    {
        throw new MembershipCreateUserException(MembershipCreateStatus.InvalidPassword);
    }

    string passwordHash = Crypto.HashPassword(password);
    // The Password column is nvarchar(128).
    if (passwordHash.Length > 128)
    {
        throw new MembershipCreateUserException(MembershipCreateStatus.InvalidPassword);
    }

    if (userName.IsEmpty())
    {
        throw new MembershipCreateUserException(MembershipCreateStatus.InvalidUserName);
    }

    using (var db = ConnectToDatabase())
    {
        // The user must already exist in the application's user table...
        int existingUserId = GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, userName);
        if (existingUserId == -1)
        {
            throw new MembershipCreateUserException(MembershipCreateStatus.ProviderError);
        }

        // ...but must not have a membership row yet.
        var membershipCount = db.QuerySingle(@"SELECT COUNT(*) FROM [" + MembershipTableName + "] WHERE UserId = @0", existingUserId);
        if (membershipCount[0] > 0)
        {
            throw new MembershipCreateUserException(MembershipCreateStatus.DuplicateUserName);
        }

        string confirmationToken = requireConfirmationToken ? GenerateToken() : null;
        // With no token, DBNull.Value is passed so the ConfirmationToken column is written as NULL.
        object confirmationTokenValue = requireConfirmationToken ? (object)confirmationToken : DBNull.Value;
        int initialPasswordFailures = 0;

        int rowsInserted = db.Execute(@"INSERT INTO [" + MembershipTableName + "] (UserId, [Password], PasswordSalt, IsConfirmed, ConfirmationToken, CreateDate, PasswordChangedDate, PasswordFailuresSinceLastSuccess)"
            + " VALUES (@0, @1, @2, @3, @4, @5, @5, @6)",
            existingUserId,
            passwordHash,
            String.Empty /* salt column is unused */,
            !requireConfirmationToken,
            confirmationTokenValue,
            DateTime.UtcNow,
            initialPasswordFailures);
        if (rowsInserted != 1)
        {
            throw new MembershipCreateUserException(MembershipCreateStatus.ProviderError);
        }

        return confirmationToken;
    }
}
// Inherited from MembershipProvider ==> Forwarded to previous provider if this provider hasn't been initialized.
// Once initialized, this operation is not supported and throws NotSupportedException.
public override MembershipUser CreateUser(string username, string password, string email, string passwordQuestion, string passwordAnswer, bool isApproved, object providerUserKey, out MembershipCreateStatus status)
{
    if (InitializeCalled)
    {
        throw new NotSupportedException();
    }
    return PreviousProvider.CreateUser(username, password, email, passwordQuestion, passwordAnswer, isApproved, providerUserKey, out status);
}
private void CreateUserRow(IDatabase db, string userName, IDictionary values)
{
// Make sure user doesn't exist
int userId = GetUserId(db, SafeUserTableName, SafeUserNameColumn, SafeUserIdColumn, userName);
if (userId != -1)
{
throw new MembershipCreateUserException(MembershipCreateStatus.DuplicateUserName);
}
StringBuilder columnString = new StringBuilder();
columnString.Append(SafeUserNameColumn);
StringBuilder argsString = new StringBuilder();
argsString.Append("@0");
List