diff --git a/lib/casserver/authenticators/sql.rb b/lib/casserver/authenticators/sql.rb index e86976b..0707619 100644 --- a/lib/casserver/authenticators/sql.rb +++ b/lib/casserver/authenticators/sql.rb @@ -54,8 +54,9 @@ end class CASServer::Authenticators::SQL < CASServer::Authenticators::Base def self.setup(options) raise CASServer::AuthenticatorError, "Invalid authenticator configuration!" unless options[:database] - - user_model_name = "CASUser_#{options[:auth_index]}" + auth_index = options[:auth_index] + user_table = options[:user_table] || 'users' + user_model_name = "CASUser_#{auth_index}" $LOG.debug "CREATING USER MODEL #{user_model_name}" class_eval %{ @@ -63,24 +64,26 @@ class CASServer::Authenticators::SQL < CASServer::Authenticators::Base end } - @user_model = const_get(user_model_name) - @user_model.establish_connection(options[:database]) + user_model = const_get(user_model_name) + # Register new user module, identified by auth_index. + user_models[auth_index] = user_model + user_model.establish_connection(options[:database]) if ActiveRecord::VERSION::STRING >= '3.2' - @user_model.table_name = (options[:user_table] || 'users') + user_model.table_name = user_table else - @user_model.set_table_name(options[:user_table] || 'users') + user_model.set_table_name(user_table) end - @user_model.inheritance_column = 'no_inheritance_column' if options[:ignore_type_column] + user_model.inheritance_column = 'no_inheritance_column' if options[:ignore_type_column] begin - @user_model.connection + user_model.connection rescue => e $LOG.debug e raise "SQL Authenticator can not connect to database" end end - def self.user_model - @user_model + def self.user_models + @user_models ||= {} end def validate(credentials) @@ -89,15 +92,16 @@ class CASServer::Authenticators::SQL < CASServer::Authenticators::Base log_connection_pool_size user_model.connection_pool.checkin(user_model.connection) + users = matching_users - if matching_users.size > 0 - $LOG.warn("#{self.class}: Multiple matches found for user #{@username.inspect}") if matching_users.size > 1 + if users.size > 0 + $LOG.warn("#{self.class}: Multiple matches found for user #{@username.inspect}") if users.size > 1 unless @options[:extra_attributes].blank? - if matching_users.size > 1 + if users.size > 1 $LOG.warn("#{self.class}: Unable to extract extra_attributes because multiple matches were found for #{@username.inspect}") else - user = matching_users.first + user = users.first extract_extra(user) log_extra @@ -113,7 +117,11 @@ class CASServer::Authenticators::SQL < CASServer::Authenticators::Base protected def user_model - self.class.user_model + self.class.user_models[auth_index] + end + + def auth_index + @options[:auth_index] end def username_column diff --git a/spec/casserver/authenticators/sql_spec.rb b/spec/casserver/authenticators/sql_spec.rb new file mode 100644 index 0000000..c935768 --- /dev/null +++ b/spec/casserver/authenticators/sql_spec.rb @@ -0,0 +1,116 @@ +require 'spec_helper' + +describe CASServer::Authenticators::SQL do + let(:options) do + { + auth_index: 0, + user_table: 'users', + username_column: 'username', + password_column: 'password', + database: { + adapter: 'mysql2', + database: 'casserver', + username: 'root', + password: 'password', + host: 'localhost' + } + } + end + let(:connection) { double('Connection', run_callbacks: nil) } + let(:connection_pool) { double('ConnectionPool', + connections: [connection], + checkin: nil) } + + before do + load_server('default_config') if $LOG.nil? # ensure logger is present + ActiveRecord::Base.stub(:establish_connection) + ActiveRecord::Base.stub(:connection).and_return(connection) + ActiveRecord::Base.stub(:connection_pool).and_return(connection_pool) + CASServer::Authenticators::SQL.setup(options) + end + + describe '#validate' do + let(:auth) { CASServer::Authenticators::SQL.new } + let(:username) { 'dave' } + let(:password) { 'secret' } + let(:user_model) { CASServer::Authenticators::SQL.user_models[0] } + + before do + auth.configure(HashWithIndifferentAccess.new(options)) + end + + context 'when credentials match a user in the database' do + it 'returns true' do + conditions = ['username = ? AND password = ?', username, password] + user_model.should_receive(:find).with(:all, conditions: conditions) + .and_return([:user]) + credentials = { + username: username, + password: password + } + expect(auth.validate(credentials)).to be true + end + end + + context 'when credentials do not match a user in the database' do + it 'returns false' do + conditions = ['username = ? AND password = ?', username, password] + user_model.should_receive(:find).with(:all, conditions: conditions) + .and_return([]) + credentials = { + username: username, + password: password + } + expect(auth.validate(credentials)).to be false + end + end + + context 'when many SQL authenticators have been setup' do + let(:alt_options) do + { + auth_index: 1, + user_table: 'users', + username_column: 'username', + password_column: 'password', + database: { + adapter: 'mysql2', + database: 'casserver', + username: 'root', + password: 'password', + host: 'localhost' + } + } + end + + before do + CASServer::Authenticators::SQL.setup(alt_options) + end + + it 'chooses the correct user model based upon auth_index' do + # Original authenticator + conditions = ['username = ? AND password = ?', username, password] + user_model.should_receive(:find).with(:all, conditions: conditions) + .and_return([:user]) + credentials = { + username: username, + password: password + } + expect(auth.validate(credentials)).to be true + + # Alternate authenticator, different credentials, different user model + alt_auth = CASServer::Authenticators::SQL.new + alt_user_model = CASServer::Authenticators::SQL.user_models[1] + alt_username = 'dan' + conditions = ['username = ? AND password = ?', alt_username, password] + alt_user_model.should_receive(:find).with(:all, conditions: conditions) + .and_return([:user]) + alt_credentials = { + username: alt_username, + password: password + } + alt_auth.configure(HashWithIndifferentAccess.new(alt_options)) + expect(alt_auth.validate(alt_credentials)).to be true + end + end + end +end