diff --git a/AirlineTestDB.bak b/AirlineTestDB.bak new file mode 100644 index 0000000..b07e9fe Binary files /dev/null and b/AirlineTestDB.bak differ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..bc329d2 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,14 @@ +# Contributing + +This project welcomes contributions and suggestions. Most contributions require you to agree to a +Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us +the rights to use your contribution. For details, visit https://cla.microsoft.com. + +When you submit a pull request, a CLA-bot will automatically determine whether you need to provide +a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions +provided by the bot. You will only need to do this once across all repos using our CLA. + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). +For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or +contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. + diff --git a/LICENSE b/LICENSE index 2107107..e39502a 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,25 @@ - MIT License +------------------------------------------- START OF LICENSE ----------------------------------------- +sqlmlutils - Copyright (c) Microsoft Corporation. All rights reserved. 
+MIT License - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: +Copyright (c) Microsoft Corporation. All rights reserved. - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE +----------------------------------------------- END OF LICENSE ------------------------------------------ diff --git a/Python/LICENSE b/Python/LICENSE new file mode 100644 index 0000000..e39502a --- /dev/null +++ b/Python/LICENSE @@ -0,0 +1,25 @@ +------------------------------------------- START OF LICENSE ----------------------------------------- +sqlmlutils + +MIT License + +Copyright (c) Microsoft Corporation. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE +----------------------------------------------- END OF LICENSE ------------------------------------------ diff --git a/Python/MANIFEST b/Python/MANIFEST new file mode 100644 index 0000000..1d2996c --- /dev/null +++ b/Python/MANIFEST @@ -0,0 +1,19 @@ +# file GENERATED by distutils, do NOT edit +setup.py +sqlmlutils\__init__.py +sqlmlutils\connectioninfo.py +sqlmlutils\sqlbuilder.py +sqlmlutils\sqlpythonexecutor.py +sqlmlutils\sqlqueryexecutor.py +sqlmlutils\storedprocedure.py +sqlmlutils/packagemanagement\__init__.py +sqlmlutils/packagemanagement\dependencyresolver.py +sqlmlutils/packagemanagement\download_script.py +sqlmlutils/packagemanagement\messages.py +sqlmlutils/packagemanagement\outputcapture.py +sqlmlutils/packagemanagement\packagesqlbuilder.py +sqlmlutils/packagemanagement\pipdownloader.py +sqlmlutils/packagemanagement\pkgutils.py +sqlmlutils/packagemanagement\scope.py +sqlmlutils/packagemanagement\servermethods.py +sqlmlutils/packagemanagement\sqlpackagemanager.py diff --git a/Python/README.md b/Python/README.md new file mode 100644 index 0000000..65a99c4 --- /dev/null +++ b/Python/README.md @@ -0,0 +1,220 @@ +# sqlmlutils + +sqlmlutils is a python package to help execute Python code on a SQL Server machine. It is built to work with ML Services for SQL Server. 
+ +# Installation + +Run +``` +python.exe -m pip install dist/sqlmlutils-0.5.0.zip --upgrade +``` +OR +To build a new package file and install, run +``` +.\buildandinstall.cmd +``` + +Note: If you encounter errors installing the pymssql dependency and your client is a Windows machine, consider +installing the .whl file at the below link (download the file for your Python version and run pip install): +https://www.lfd.uci.edu/~gohlke/pythonlibs/#pymssql + +# Getting started + +Shown below are the important functions sqlmlutils provides: +```python +execute_function_in_sql # Execute a python function inside the SQL database +execute_script_in_sql # Execute a python script inside the SQL database +execute_sql_query # Execute a sql query in the database and return the resultant table + +create_sproc_from_function # Create a stored procedure based on a Python function inside the SQL database +create_sproc_from_script # Create a stored procedure based on a Python script inside the SQL database +check_sproc # Check whether a stored procedure exists in the SQL database +drop_sproc # Drop a stored procedure from the SQL database +execute_sproc # Execute a stored procedure in the SQL database + +install_package # Install a Python package on the SQL database +remove_package # Remove a Python package from the SQL database +list # Enumerate packages that are installed on the SQL database +``` + +# Examples + +### Execute in SQL +##### Execute a python function in database + +```python +import sqlmlutils + +def foo(): + return "bar" + +sqlpy = sqlmlutils.SQLPythonExecutor(sqlmlutils.ConnectionInfo(server="localhost", database="master")) +result = sqlpy.execute_function_in_sql(foo) +assert result == "bar" +``` + +##### Generate a scatter plot without the data leaving the machine + +```python +import sqlmlutils +from PIL import Image + + +def scatter_plot(input_df, x_col, y_col): + import matplotlib.pyplot as plt + import io + + title = x_col + " vs. 
" + y_col + + plt.scatter(input_df[x_col], input_df[y_col]) + plt.xlabel(x_col) + plt.ylabel(y_col) + plt.title(title) + + # Save scatter plot image as a png + buf = io.BytesIO() + plt.savefig(buf, format="png") + buf.seek(0) + + # Returns the bytes of the png to the client + return buf + + +sqlpy = sqlmlutils.SQLPythonExecutor(sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB")) + +sql_query = "select top 100 * from airline5000" +plot_data = sqlpy.execute_function_in_sql(func=scatter_plot, input_data_query=sql_query, + x_col="ArrDelay", y_col="CRSDepTime") +im = Image.open(plot_data) +im.show() +``` + +##### Perform linear regression on data stored in SQL Server without the data leaving the machine + +You can use the AirlineTestDB (supplied as a .bak file above) to run these examples. + +```python +import sqlmlutils + + +def linear_regression(input_df, x_col, y_col): + from sklearn import linear_model + + X = input_df[[x_col]] + y = input_df[y_col] + + lr = linear_model.LinearRegression() + lr.fit(X, y) + + return lr + + +sqlpy = sqlmlutils.SQLPythonExecutor(sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB")) +sql_query = "select top 1000 CRSDepTime, CRSArrTime from airline5000" +regression_model = sqlpy.execute_function_in_sql(linear_regression, input_data_query=sql_query, + x_col="CRSDepTime", y_col="CRSArrTime") +print(regression_model) +print(regression_model.coef_) +``` + +##### Execute a SQL Query from Python + +```python +import sqlmlutils +import pytest + +sqlpy = sqlmlutils.SQLPythonExecutor(sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB")) +sql_query = "select top 10 * from airline5000" +data_table = sqlpy.execute_sql_query(sql_query) +assert len(data_table.columns) == 30 +assert len(data_table) == 10 +``` + +### Stored Procedure +##### Create and call a T-SQL stored procedure based on a Python function + +```python +import sqlmlutils +import pytest + +def principal_components(input_table: 
str, output_table: str): + import sqlalchemy + from urllib import parse + import pandas as pd + from sklearn.decomposition import PCA + + # Internal ODBC connection string used by process executing inside SQL Server + connection_string = "Driver=SQL Server;Server=localhost;Database=AirlineTestDB;Trusted_Connection=Yes;" + engine = sqlalchemy.create_engine("mssql+pyodbc:///?odbc_connect={}".format(parse.quote_plus(connection_string))) + + input_df = pd.read_sql("select top 200 ArrDelay, CRSDepTime from {}".format(input_table), engine).dropna() + + + pca = PCA(n_components=2) + components = pca.fit_transform(input_df) + + output_df = pd.DataFrame(components) + output_df.to_sql(output_table, engine, if_exists="replace") + + +connection = sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB") + +input_table = "airline5000" +output_table = "AirlineDemoPrincipalComponents" + +sp_name = "SavePrincipalComponents" + +sqlpy = sqlmlutils.SQLPythonExecutor(connection) + +if sqlpy.check_sproc(sp_name): + sqlpy.drop_sproc(sp_name) + +sqlpy.create_sproc_from_function(sp_name, principal_components) + +# You can check the stored procedure exists in the db with this: +assert sqlpy.check_sproc(sp_name) + +sqlpy.execute_sproc(sp_name, input_table=input_table, output_table=output_table) + +sqlpy.drop_sproc(sp_name) +assert not sqlpy.check_sproc(sp_name) +``` + +### Package Management +##### Install and remove packages from SQL Server + +```python +import sqlmlutils + +connection = sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB") +sqlpy = sqlmlutils.SQLPythonExecutor(connection) +pkgmanager = sqlmlutils.SQLPackageManager(connection) + +def use_tensorflow(): + import tensorflow as tf + node1 = tf.constant(3.0, tf.float32) + return str(node1.dtype) + +pkgmanager.install("tensorflow") +val = sqlpy.execute_function_in_sql(use_tensorflow) + +pkgmanager.uninstall("tensorflow") +``` + + +# Notes for Developers + +### Running the tests + +1. 
Make sure a SQL Server instance with an updated ML Services Python is running on localhost. +2. Restore the AirlineTestDB from the .bak file in this repo. +3. Make sure Trusted (Windows) authentication works for connecting to the database. +4. Set up a user with the db_owner role, uid "Tester", and password "FakeTesterPwd". + +### Notable TODOs and open issues + +1. The pymssql library is hard to install. Users need to install the .whl files from the link above, not +the .whl files currently hosted on PyPI. Because of this, we should consider moving to pyodbc. +2. Testing from a Linux client has not been performed. +3. The way we get the dependencies of a package to install is somewhat hacky (it parses pip output). +4. Output parameter execution currently does not work; we could potentially use MSSQLStoredProcedure binding. diff --git a/Python/buildandinstall.cmd b/Python/buildandinstall.cmd new file mode 100644 index 0000000..8ef7c0d --- /dev/null +++ b/Python/buildandinstall.cmd @@ -0,0 +1,2 @@ +python.exe setup.py sdist +python.exe -m pip install dist\sqlmlutils-0.5.0.zip --upgrade diff --git a/Python/dist/sqlmlutils-0.5.0.zip b/Python/dist/sqlmlutils-0.5.0.zip new file mode 100644 index 0000000..ccd1c9b Binary files /dev/null and b/Python/dist/sqlmlutils-0.5.0.zip differ diff --git a/Python/setup.py b/Python/setup.py new file mode 100644 index 0000000..1c87cd7 --- /dev/null +++ b/Python/setup.py @@ -0,0 +1,25 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +from distutils.core import setup + +setup( + name='sqlmlutils', + packages=['sqlmlutils', 'sqlmlutils/packagemanagement'], + version='0.5.0', + url='https://github.com/Microsoft/sqlmlutils', + license='MIT License', + description='A client side package for working with SQL Machine Learning Python Services. 
' + 'sqlmlutils enables easy package installation and remote code execution on your SQL Server machine.', + author='Microsoft', + author_email='joz@microsoft.com', + install_requires=[ + 'pip', + 'pymssql', + 'dill', + 'pkginfo', + 'requirements-parser', + 'pandas' + ], + python_requires='>=3.5' +) diff --git a/Python/sqlmlutils/__init__.py b/Python/sqlmlutils/__init__.py new file mode 100644 index 0000000..526105f --- /dev/null +++ b/Python/sqlmlutils/__init__.py @@ -0,0 +1,6 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +from .connectioninfo import ConnectionInfo +from .sqlpythonexecutor import SQLPythonExecutor +from .packagemanagement.sqlpackagemanager import SQLPackageManager diff --git a/Python/sqlmlutils/connectioninfo.py b/Python/sqlmlutils/connectioninfo.py new file mode 100644 index 0000000..d9e7f1b --- /dev/null +++ b/Python/sqlmlutils/connectioninfo.py @@ -0,0 +1,55 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +class ConnectionInfo: + """Information needed to connect to SQL Server. + + """ + + def __init__(self, driver: str = "SQL Server", server: str = "localhost", database: str = "master", + uid: str = "", pwd: str = ""): + """ + :param driver: Driver to use to connect to SQL Server. + :param server: SQL Server hostname or a specific instance to connect to. + :param database: Database to connect to. + :param uid: uid to connect with. If not specified, utilizes trusted authentication. + :param pwd: pwd to connect with. 
If uid is not specified, pwd is ignored and trusted authentication is used. + + >>> from sqlmlutils import ConnectionInfo + >>> connection = ConnectionInfo(server="ServerName", database="DatabaseName", uid="Uid", pwd="Pwd") + """ + self._driver = driver + self._server = server + self._database = database + self._uid = uid + self._pwd = pwd + + @property + def driver(self): + return self._driver + + @property + def server(self): + return self._server + + @property + def database(self): + return self._database + + @property + def uid(self): + return self._uid + + @property + def pwd(self): + return self._pwd + + @property + def connection_string(self): + return "Driver={driver};Server={server};Database={database};{auth};".format( + driver=self._driver, + server=self._server, + database=self._database, + auth="Trusted_Connection=Yes" if self._uid == "" else + "uid={uid};pwd={pwd}".format(uid=self._uid, pwd=self._pwd) + ) diff --git a/Python/sqlmlutils/packagemanagement/__init__.py b/Python/sqlmlutils/packagemanagement/__init__.py new file mode 100644 index 0000000..d9f035d --- /dev/null +++ b/Python/sqlmlutils/packagemanagement/__init__.py @@ -0,0 +1 @@ +from .sqlpackagemanager import SQLPackageManager diff --git a/Python/sqlmlutils/packagemanagement/dependencyresolver.py b/Python/sqlmlutils/packagemanagement/dependencyresolver.py new file mode 100644 index 0000000..c5b110b --- /dev/null +++ b/Python/sqlmlutils/packagemanagement/dependencyresolver.py @@ -0,0 +1,60 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
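For reference, the `connection_string` property of `ConnectionInfo` above can be exercised without a live server. The sketch below restates its logic inline (the helper name `build_connection_string` is illustrative, not part of sqlmlutils):

```python
# Standalone restatement of ConnectionInfo.connection_string, for illustration.
# Trusted (Windows) authentication is used when no uid is given; otherwise the
# uid and pwd are embedded in the string.
def build_connection_string(driver="SQL Server", server="localhost",
                            database="master", uid="", pwd=""):
    auth = ("Trusted_Connection=Yes" if uid == ""
            else "uid={uid};pwd={pwd}".format(uid=uid, pwd=pwd))
    return "Driver={driver};Server={server};Database={database};{auth};".format(
        driver=driver, server=server, database=database, auth=auth)

print(build_connection_string())
# Driver=SQL Server;Server=localhost;Database=master;Trusted_Connection=Yes;
print(build_connection_string(uid="Tester", pwd="FakeTesterPwd"))
# Driver=SQL Server;Server=localhost;Database=master;uid=Tester;pwd=FakeTesterPwd;
```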
+ +import operator + +from distutils.version import LooseVersion + + +class DependencyResolver: + + def __init__(self, server_packages, target_package): + self._server_packages = server_packages + self._target_package = target_package + + def requirement_met(self, upgrade: bool, version: str = None) -> bool: + exists = self._package_exists_on_server() + return exists and (not upgrade or + (version is not None and self.get_target_server_version() != "" and + LooseVersion(self.get_target_server_version()) >= LooseVersion(version))) + + def get_target_server_version(self): + for package in self._server_packages: + if package[0].lower() == self._target_package.lower(): + return package[1] + return "" + + def get_required_installs(self, target_requirements): + required_packages = [] + for requirement in target_requirements: + reqmet = any(package[0] == requirement.name for package in self._server_packages) + + for spec in requirement.specs: + reqmet = reqmet & self._check_if_installed_package_meets_spec( + self._server_packages, requirement.name, spec) + + if not reqmet or requirement.name == self._target_package: + required_packages.append(self.clean_requirement_name(requirement.name)) + return required_packages + + def _package_exists_on_server(self): + return any([serverpkg[0].lower() == self._target_package.lower() for serverpkg in self._server_packages]) + + @staticmethod + def clean_requirement_name(reqname: str): + return reqname.replace("-", "_") + + @staticmethod + def _check_if_installed_package_meets_spec(package_tuples, name, spec): + op_str = spec[0] + req_version = spec[1] + + installed_package_name_and_version = [package for package in package_tuples if package[0] == name] + if not installed_package_name_and_version: + return False + + installed_package_name_and_version = installed_package_name_and_version[0] + installed_version = installed_package_name_and_version[1] + + operator_map = {'>': 'gt', '>=': 'ge', '<': 'lt', '==': 'eq', '<=': 'le', '!=': 
'ne'} + return getattr(operator, operator_map[op_str])(LooseVersion(installed_version), LooseVersion(req_version)) diff --git a/Python/sqlmlutils/packagemanagement/download_script.py b/Python/sqlmlutils/packagemanagement/download_script.py new file mode 100644 index 0000000..c843518 --- /dev/null +++ b/Python/sqlmlutils/packagemanagement/download_script.py @@ -0,0 +1,26 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +from distutils.version import LooseVersion +import pip +import warnings +import sys + +pipversion = LooseVersion(pip.__version__) +if pipversion > LooseVersion("10"): + from pip._internal import pep425tags + from pip._internal import main as pipmain +else: + if pipversion < LooseVersion("8.1.2"): + warnings.warn("Pip version less than 8.1.2 not supported.", Warning) + from pip import pep425tags + from pip import main as pipmain + +# Monkey patch the pip version information with server information +pep425tags.get_impl_version_info = lambda: eval(sys.argv[1]) +pep425tags.get_abbr_impl = lambda: sys.argv[2] +pep425tags.get_abi_tag = lambda: sys.argv[3] +pep425tags.get_platform = lambda: sys.argv[4] + +# Call pipmain with the download request +pipmain(list(map(str.strip, sys.argv[5].split(",")))) diff --git a/Python/sqlmlutils/packagemanagement/messages.py b/Python/sqlmlutils/packagemanagement/messages.py new file mode 100644 index 0000000..61382b7 --- /dev/null +++ b/Python/sqlmlutils/packagemanagement/messages.py @@ -0,0 +1,11 @@ +def no_upgrade(pkgname: str, serverversion: str, pkgversion: str = ""): + return """ +Package {pkgname} exists on server. Set upgrade to True to force upgrade. +The version of {pkgname} you are trying to install is {pkgversion}. 
+The version installed on the server is {serverversion} + """.format(pkgname=pkgname, pkgversion=pkgversion, serverversion=serverversion) + + +def install(pkgname: str, version: str, targetpackage: bool): + target = "target package" if targetpackage else "required dependency" + return "Installing {} {} version {}".format(target, pkgname, version) diff --git a/Python/sqlmlutils/packagemanagement/outputcapture.py b/Python/sqlmlutils/packagemanagement/outputcapture.py new file mode 100644 index 0000000..b9eee0e --- /dev/null +++ b/Python/sqlmlutils/packagemanagement/outputcapture.py @@ -0,0 +1,9 @@ +import sys +import io + + +class OutputCapture(io.StringIO): + + def write(self, txt): + sys.__stdout__.write(txt) + super().write(txt) diff --git a/Python/sqlmlutils/packagemanagement/packagesqlbuilder.py b/Python/sqlmlutils/packagemanagement/packagesqlbuilder.py new file mode 100644 index 0000000..7cbe795 --- /dev/null +++ b/Python/sqlmlutils/packagemanagement/packagesqlbuilder.py @@ -0,0 +1,130 @@ +from sqlmlutils.sqlbuilder import SQLBuilder +from sqlmlutils.packagemanagement.scope import Scope + + +class CreateLibraryBuilder(SQLBuilder): + + def __init__(self, pkg_name: str, pkg_filename: str, scope: Scope): + self._name = clean_library_name(pkg_name) + self._filename = pkg_filename + self._has_params = True + self._scope = scope + + @property + def params(self): + with open(self._filename, "rb") as f: + pkgdatastr = "0x" + f.read().hex() + + installcheckscript = """ +import os +import re +_ENV_NAME_USER_PATH = "MRS_EXTLIB_USER_PATH" +_ENV_NAME_SHARED_PATH = "MRS_EXTLIB_SHARED_PATH" + + +def _is_dist_info_file(name, file): + return re.match(name + r'-.*egg', file) or re.match(name + r'-.*dist-info', file) + + +def _is_package_match(package_name, file): + package_name = package_name.lower() + file = file.lower() + return file == package_name or file == package_name + ".py" or \ + _is_dist_info_file(package_name, file) or \ + ("-" in package_name and + 
(package_name.split("-")[0] == file or _is_dist_info_file(package_name.replace("-", "_"), file))) + +def package_files_in_scope(scope='private'): + envdir = _ENV_NAME_SHARED_PATH if scope == 'public' or os.environ.get(_ENV_NAME_USER_PATH, "") == "" \ + else _ENV_NAME_USER_PATH + path = os.environ.get(envdir, "") + if os.path.isdir(path): + return os.listdir(path) + return [] + +def package_exists_in_scope(sql_package_name: str, scope=None) -> bool: + if scope is None: + # default to user path for every user but DBOs + scope = 'public' if (os.environ.get(_ENV_NAME_USER_PATH, "") == "") else 'private' + package_files = package_files_in_scope(scope) + return any([_is_package_match(sql_package_name, package_file) for package_file in package_files]) + + +assert package_exists_in_scope("{sqlpkgname}", "{scopestr}") +""".format(sqlpkgname=self._name, scopestr=self._scope._name) + + return pkgdatastr, installcheckscript + + @property + def base_script(self) -> str: + return """ +-- Wrap this in a transaction +DECLARE @TransactionName varchar(30) = 'SqlPackageTransaction'; +BEGIN TRAN @TransactionName + +-- Drop the library if it exists +BEGIN TRY +DROP EXTERNAL LIBRARY [{sqlpkgname}] {authorization} +END TRY +BEGIN CATCH +END CATCH + +-- Parameter bind the package data +DECLARE @content varbinary(MAX) = convert(varbinary(MAX), %s, 1); + +-- Create the library +CREATE EXTERNAL LIBRARY [{sqlpkgname}] {authorization} +FROM (CONTENT = @content) WITH (LANGUAGE = 'Python'); + +-- Dummy SPEES +{dummy_spees} + +-- Check to make sure the package was installed +BEGIN TRY + exec sp_execute_external_script + @language = N'Python', + @script = %s + -- Installation succeeded, commit the transaction + COMMIT TRAN @TransactionName + print('Package successfully installed.') +END TRY +BEGIN CATCH + -- Installation failed, rollback the transaction + ROLLBACK TRAN @TransactionName + print('Package installation failed.'); + THROW; +END CATCH +""".format(sqlpkgname=self._name, + 
authorization=_get_authorization(self._scope), + dummy_spees=_get_dummy_spees()) + + +class DropLibraryBuilder(SQLBuilder): + + def __init__(self, sql_package_name: str, scope: Scope): + self._name = clean_library_name(sql_package_name) + self._scope = scope + + @property + def base_script(self) -> str: + return """ +DROP EXTERNAL LIBRARY [{}] {authorization} + +{dummy_spees} +""".format(self._name, authorization=_get_authorization(self._scope), dummy_spees=_get_dummy_spees()) + + +def clean_library_name(pkgname: str): + return pkgname.replace("-", "_").lower() + + +def _get_authorization(scope: Scope) -> str: + return "AUTHORIZATION dbo" if scope == Scope.public_scope() else "" + + +def _get_dummy_spees() -> str: + return """ +exec sp_execute_external_script +@language = N'Python', +@script = N'' +""" diff --git a/Python/sqlmlutils/packagemanagement/pipdownloader.py b/Python/sqlmlutils/packagemanagement/pipdownloader.py new file mode 100644 index 0000000..c4424b9 --- /dev/null +++ b/Python/sqlmlutils/packagemanagement/pipdownloader.py @@ -0,0 +1,82 @@ +import re +import requirements +import subprocess +import os + +from sqlmlutils import ConnectionInfo, SQLPythonExecutor +from sqlmlutils.packagemanagement import servermethods + +class PipDownloader: + + def __init__(self, connection: ConnectionInfo, downloaddir: str, targetpackage: str): + self._connection = connection + self._downloaddir = downloaddir + self._targetpackage = targetpackage + server_info = SQLPythonExecutor(connection).execute_function_in_sql(servermethods.get_server_info) + globals().update(server_info) + + def download(self): + return self._download(True) + + def download_single(self) -> str: + _, pkgsdownloaded = self._download(False) + return pkgsdownloaded[0] + + def _download(self, withdependencies): + # This command directs pip to download the target package, as well as all of its dependencies into + # temporary_directory. 
+ commands = ["download", self._targetpackage, "--destination-dir", self._downloaddir, "--no-cache-dir"] + if not withdependencies: + commands.append("--no-dependencies") + + output, error = self._run_in_new_process(commands) + + pkgreqs = self._get_reqs_from_output(output) + + packagesdownloaded = [os.path.join(self._downloaddir, f) for f in os.listdir(self._downloaddir) + if os.path.isfile(os.path.join(self._downloaddir, f))] + + return pkgreqs, packagesdownloaded + + def _run_in_new_process(self, commands): + # We get the package requirements based on the print output of pip, which is stable across version 8-10. + # TODO: get requirements in a more robust way (either through using pip internal code or rolling our own) + download_script = os.path.join((os.path.dirname(os.path.realpath(__file__))), "download_script.py") + args = ["python", download_script, + str(_patch_get_impl_version_info()), str(_patch_get_abbr_impl()), + str(_patch_get_abi_tag()), str(_patch_get_platform()), + ",".join(str(x) for x in commands)] + + with subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as proc: + output = proc.stdout.read() + error = proc.stderr.read() + + return output.decode(), error.decode() + + @staticmethod + def _get_reqs_from_output(pipoutput: str): + # TODO: get requirements in a more robust way (either through using pip internal code or rolling our own) + collectinglines = [line for line in pipoutput.splitlines() if "Collecting" in line] + + f = lambda unclean: \ + re.sub(r'\(.*\)', "", unclean.replace("Collecting ", "").strip()) + + reqstr = "\n".join([f(line) for line in collectinglines]) + return list(requirements.parse(reqstr)) + + +def _patch_get_impl_version_info(): + return globals()["impl_version_info"] + + +def _patch_get_abbr_impl(): + return globals()["abbr_impl"] + + +def _patch_get_abi_tag(): + return globals()["abi_tag"] + + +def _patch_get_platform(): + return globals()["platform"] + diff --git 
a/Python/sqlmlutils/packagemanagement/pkgutils.py b/Python/sqlmlutils/packagemanagement/pkgutils.py new file mode 100644 index 0000000..5871cd4 --- /dev/null +++ b/Python/sqlmlutils/packagemanagement/pkgutils.py @@ -0,0 +1,29 @@ +import pkginfo +import os +import re + + +def _get_pkginfo(filename: str): + try: + if ".whl" in filename: + return pkginfo.Wheel(filename) + else: + return pkginfo.SDist(filename) + except Exception: + return None + + +def get_package_name_from_file(filename: str) -> str: + pkg = _get_pkginfo(filename) + if pkg is not None and pkg.name is not None: + return pkg.name + name = os.path.splitext(os.path.basename(filename))[0] + return re.sub(r"\-[0-9].*", "", name) + + +def get_package_version_from_file(filename: str): + pkg = _get_pkginfo(filename) + if pkg is not None and pkg.version is not None: + return pkg.version + return None + diff --git a/Python/sqlmlutils/packagemanagement/scope.py b/Python/sqlmlutils/packagemanagement/scope.py new file mode 100644 index 0000000..311ae87 --- /dev/null +++ b/Python/sqlmlutils/packagemanagement/scope.py @@ -0,0 +1,20 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +class Scope: + + def __init__(self, name: str): + self._name = name + + def __eq__(self, other): + return self._name == other._name + + @staticmethod + def public_scope(): + return Scope("public") + + @staticmethod + def private_scope(): + return Scope("private") + + diff --git a/Python/sqlmlutils/packagemanagement/servermethods.py b/Python/sqlmlutils/packagemanagement/servermethods.py new file mode 100644 index 0000000..86f9f94 --- /dev/null +++ b/Python/sqlmlutils/packagemanagement/servermethods.py @@ -0,0 +1,78 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
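When pkginfo cannot read the archive metadata, `get_package_name_from_file` in pkgutils.py above falls back to stripping the version suffix from the filename with a regex. That fallback can be exercised in isolation (a sketch restating the source logic; `name_from_filename` is an illustrative name, not the packaged function):

```python
import os
import re

# Filename fallback used when pkginfo cannot parse the archive:
# drop the extension, then strip "-<digit>..." (the version suffix and
# anything after it) from the remaining basename.
def name_from_filename(filename):
    name = os.path.splitext(os.path.basename(filename))[0]
    return re.sub(r"\-[0-9].*", "", name)

print(name_from_filename("tensorflow-1.12.0-cp35-cp35m-win_amd64.whl"))
# tensorflow
print(name_from_filename("requirements-parser-0.2.0.tar.gz"))
# requirements-parser
```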
+
+from sqlmlutils.packagemanagement.scope import Scope
+import os
+import re
+
+_ENV_NAME_USER_PATH = "MRS_EXTLIB_USER_PATH"
+_ENV_NAME_SHARED_PATH = "MRS_EXTLIB_SHARED_PATH"
+
+
+def show_installed_packages():
+    from distutils.version import LooseVersion
+    import pip
+    if LooseVersion(pip.__version__) > LooseVersion("10"):
+        from pip._internal.operations import freeze
+    else:
+        from pip.operations import freeze
+
+    packages = []
+    for package in list(freeze.freeze()):
+        val = package.split("==")
+        name = val[0]
+        version = val[1]
+        packages.append((name, version))
+    return packages
+
+
+def get_server_info():
+    from distutils.version import LooseVersion
+    import pip
+    if LooseVersion(pip.__version__) > LooseVersion("10"):
+        from pip._internal import pep425tags
+    else:
+        from pip import pep425tags
+    return {
+        "impl_version_info": pep425tags.get_impl_version_info(),
+        "abbr_impl": pep425tags.get_abbr_impl(),
+        "abi_tag": pep425tags.get_abi_tag(),
+        "platform": pep425tags.get_platform()
+    }
+
+
+def check_package_install_success(sql_package_name: str) -> bool:
+    return package_exists_in_scope(sql_package_name)
+
+
+def package_files_in_scope(scope=Scope.private_scope()):
+    envdir = _ENV_NAME_SHARED_PATH if scope == Scope.public_scope() or os.environ.get(_ENV_NAME_USER_PATH, "") == "" \
+        else _ENV_NAME_USER_PATH
+    path = os.environ.get(envdir, "")
+    if os.path.isdir(path):
+        return os.listdir(path)
+    return []
+
+
+def package_exists_in_scope(sql_package_name: str, scope=None) -> bool:
+    if scope is None:
+        # default to user path for every user but DBOs
+        scope = Scope.public_scope() if (os.environ.get(_ENV_NAME_USER_PATH, "") == "") else Scope.private_scope()
+    package_files = package_files_in_scope(scope)
+    return any([_is_package_match(sql_package_name, package_file) for package_file in package_files])
+
+
+def _is_dist_info_file(name, file):
+    return re.match(name + r'-.*egg', file) or re.match(name + r'-.*dist-info', file)
+
+
+def _is_package_match(package_name, file):
+    package_name = package_name.lower()
+    file = file.lower()
+    return file == package_name or file == package_name + ".py" or \
+        _is_dist_info_file(package_name, file) or \
+        ("-" in package_name and
+         (package_name.split("-")[0] == file or _is_dist_info_file(package_name.replace("-", "_"), file)))
+
+
+
diff --git a/Python/sqlmlutils/packagemanagement/sqlpackagemanager.py b/Python/sqlmlutils/packagemanagement/sqlpackagemanager.py
new file mode 100644
index 0000000..eff5367
--- /dev/null
+++ b/Python/sqlmlutils/packagemanagement/sqlpackagemanager.py
@@ -0,0 +1,176 @@
+# Copyright(c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+import os
+import tempfile
+import zipfile
+import warnings
+
+from sqlmlutils import ConnectionInfo, SQLPythonExecutor
+from sqlmlutils.sqlqueryexecutor import execute_query, SQLTransaction
+from sqlmlutils.packagemanagement.packagesqlbuilder import clean_library_name
+from sqlmlutils.packagemanagement import servermethods
+from sqlmlutils.sqlqueryexecutor import SQLQueryExecutor
+from sqlmlutils.packagemanagement.dependencyresolver import DependencyResolver
+from sqlmlutils.packagemanagement.pipdownloader import PipDownloader
+from sqlmlutils.packagemanagement.scope import Scope
+from sqlmlutils.packagemanagement import messages
+from sqlmlutils.packagemanagement.pkgutils import get_package_name_from_file, get_package_version_from_file
+from sqlmlutils.packagemanagement.packagesqlbuilder import CreateLibraryBuilder, DropLibraryBuilder
+
+
+class SQLPackageManager:
+
+    def __init__(self, connection_info: ConnectionInfo):
+        self._connection_info = connection_info
+        self._pyexecutor = SQLPythonExecutor(connection_info)
+
+    def install(self,
+                package: str,
+                upgrade: bool = False,
+                version: str = None,
+                install_dependencies: bool = True,
+                scope: Scope = Scope.private_scope()):
+        """Install a Python package into a SQL Server Python Services environment using pip.
+
+        :param package: Package name to install on the SQL Server. Can also be a filename.
+        :param upgrade: If True, will update the package if it exists on the specified SQL Server.
+            If False, will not try to update an existing package.
+        :param version: Not yet supported. Package version to install. If not specified, the current
+            stable version for the server environment, as determined by the PyPI/Anaconda repos, is used.
+        :param install_dependencies: If True, installs required dependencies of package (similar to how a default
+            pip install or conda install works). False is not yet supported.
+        :param scope: Specifies whether to install packages into private or public scope. Default is private scope.
+            This installs packages into a private path for the SQL principal you connect as. If your principal has the
+            db_owner role, you can also specify scope as public. This will install packages into a public path for all
+            users. Note: if you connect as dbo, you can only install packages into the public path.
+
+        >>> from sqlmlutils import ConnectionInfo, SQLPythonExecutor, SQLPackageManager
+        >>> connection = ConnectionInfo(server="localhost", database="AirlineTestDB")
+        >>> pyexecutor = SQLPythonExecutor(connection)
+        >>> pkgmanager = SQLPackageManager(connection)
+        >>>
+        >>> def use_tensorflow():
+        >>>     import tensorflow as tf
+        >>>     node1 = tf.constant(3.0, tf.float32)
+        >>>     return str(node1.dtype)
+        >>>
+        >>> pkgmanager.install("tensorflow")
+        >>> ret = pyexecutor.execute_function_in_sql(use_tensorflow)
+        >>> pkgmanager.uninstall("tensorflow")
+
+        """
+        if not install_dependencies:
+            raise ValueError("Dependencies will always be installed - "
+                             "single package install without dependencies not yet supported.")
+
+        if os.path.isfile(package):
+            self._install_from_file(package, scope, upgrade)
+        else:
+            self._install_from_pypi(package, upgrade, version, install_dependencies, scope)
+
+    def uninstall(self, package_name: str, scope: Scope = Scope.private_scope()):
"""Remove Python package from a SQL Server Python environment. + + :param package_name: Package name to remove on the SQL Server. + :param scope: Specifies whether to uninstall packages from private or public scope. Default is private scope. + This uninstalls packages from a private path for the SQL principal you connect as. If your principal has the + db_owner role, you can also specify scope as public. This will uninstall packages from a public path for all + users. Note: if you connect as dbo, you can only uninstall packages from the public path. + """ + print("Uninstalling " + package_name + "only, not dependencies") + self._drop_sql_package(package_name, scope) + + def list(self): + """List packages installed on server, similar to output of pip freeze. + + :return: List of tuples, each tuple[0] is package name and tuple[1] is package version. + """ + return self._pyexecutor.execute_function_in_sql(servermethods.show_installed_packages) + + def _drop_sql_package(self, sql_package_name: str, scope: Scope): + builder = DropLibraryBuilder(sql_package_name=sql_package_name, scope=scope) + execute_query(builder, self._connection_info) + + # TODO: Support not dependencies + def _install_from_pypi(self, + target_package: str, + upgrade: bool = False, + version: str = None, + install_dependencies: bool = True, + scope: Scope = Scope.private_scope()): + + if not install_dependencies: + raise ValueError("Dependencies will always be installed - " + "single package install without dependencies not yet supported.") + + if version is not None: + target_package = target_package + "==" + version + + with tempfile.TemporaryDirectory() as temporary_directory: + pipdownloader = PipDownloader(self._connection_info, temporary_directory, target_package) + target_package_file = pipdownloader.download_single() + self._install_from_file(target_package_file, scope, upgrade) + + def _install_from_file(self, target_package_file: str, scope: Scope, upgrade: bool = False): + name = 
get_package_name_from_file(target_package_file) + version = get_package_version_from_file(target_package_file) + + resolver = DependencyResolver(self.list(), name) + if resolver.requirement_met(upgrade, version): + serverversion = resolver.get_target_server_version() + print(messages.no_upgrade(name, serverversion, version)) + return + + # Download requirements from PyPI + with tempfile.TemporaryDirectory() as temporary_directory: + pipdownloader = PipDownloader(self._connection_info, temporary_directory, target_package_file) + + # For now, we download all target package dependencies from PyPI. + target_package_requirements, requirements_downloaded = pipdownloader.download() + + # Resolve which package dependencies need to be installed or upgraded on server. + required_installs = resolver.get_required_installs(target_package_requirements) + dependencies_to_install = self._get_required_files_to_install(requirements_downloaded, required_installs) + + self._install_many(target_package_file, dependencies_to_install, scope) + + def _install_many(self, target_package_file: str, dependency_files, scope: Scope): + target_name = get_package_name_from_file(target_package_file) + + with SQLQueryExecutor(connection=self._connection_info) as sqlexecutor: + transaction = SQLTransaction(sqlexecutor, clean_library_name(target_name) + "InstallTransaction") + transaction.begin() + try: + for pkgfile in dependency_files: + self._install_single(sqlexecutor, pkgfile, scope) + self._install_single(sqlexecutor, target_package_file, scope, True) + transaction.commit() + except Exception: + transaction.rollback() + raise RuntimeError("Package installation failed, installed dependencies were rolled back.") + + @staticmethod + def _install_single(sqlexecutor: SQLQueryExecutor, package_file: str, scope: Scope, is_target=False): + name = get_package_name_from_file(package_file) + version = get_package_version_from_file(package_file) + + with tempfile.TemporaryDirectory() as 
temporary_directory: + prezip = os.path.join(temporary_directory, name + "PREZIP.zip") + with zipfile.ZipFile(prezip, 'w') as zipf: + zipf.write(package_file, os.path.basename(package_file)) + + builder = CreateLibraryBuilder(pkg_name=name, pkg_filename=prezip, scope=scope) + sqlexecutor.execute(builder) + + @staticmethod + def _get_required_files_to_install(pkgfiles, requirements): + return [file for file in pkgfiles + if SQLPackageManager._pkgfile_in_requirements(file, requirements)] + + @staticmethod + def _pkgfile_in_requirements(pkgfile: str, requirements): + pkgname = get_package_name_from_file(pkgfile) + return any([DependencyResolver.clean_requirement_name(pkgname.lower()) == + DependencyResolver.clean_requirement_name(req.lower()) + for req in requirements]) diff --git a/Python/sqlmlutils/sqlbuilder.py b/Python/sqlmlutils/sqlbuilder.py new file mode 100644 index 0000000..5b78d01 --- /dev/null +++ b/Python/sqlmlutils/sqlbuilder.py @@ -0,0 +1,467 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +from typing import Callable, List +import abc +import dill +import inspect +import textwrap +from pandas import DataFrame +import warnings + +RETURN_COLUMN_NAME = "return_val" + + +""" +_SQLBuilder implementations are used to generate SQL scripts to execute_function_in_sql Python functions and +create/drop/execute_function_in_sql stored procedures. + +Builder classes use query parametrization whenever possible, falling back to Python string formatting when neccesary. + +The main internal function to execute_function_in_sql SQL statements (_execute_query in the _sqlqueryexecutor module) +takes an implementation _SQLBuilder as an argument. + +All _SQLBuilder classes implement a base_script property. This is the text of the SQL query. Some builder classes +return values in their params property. 
+""" + + +class SQLBuilder: + + @abc.abstractmethod + def base_script(self) -> str: + pass + + @property + def params(self): + return None + + +class SpeesBuilder(SQLBuilder): + + """_SpeesBuilder objects are used to generate exec sp_execute_external_script SQL queries. + + """ + + def __init__(self, + script: str, + with_results_text: str = "", + input_data_query: str = "", + script_parameters_text: str = ""): + """Instantiate a _SpeesBuilder object. + + :param script: maps to @script parameter in the SQL query parameter + :param with_results_text: with results text used to defined the expected data schema of the SQL query + :param input_data_query: maps to @input_data_1 SQL query parameter + :param script_parameters_text: maps to @params SQL query parameter + """ + self._script = script + self._input_data_query = input_data_query + self._script_parameters_text = script_parameters_text + self._with_results_text = with_results_text + + @property + def base_script(self): + return """ +exec sp_execute_external_script +@language = N'Python', +@script = %s, +@input_data_1 = %s +{script_parameters} +{with_results_text} + """.format(script_parameters=self._script_parameters_text, + with_results_text=self._with_results_text) + + @property + def params(self): + return self._script, self._input_data_query + + +class SpeesBuilderFromFunction(SpeesBuilder): + + """ + _SpeesBuilderFromFunction objects are used to generate SPEES queries based on a function and given arguments. + """ + + _WITH_RESULTS_TEXT = "with result sets((return_val varchar(MAX)))" + + def __init__(self, func: Callable, input_data_query: str = "", *args, **kwargs): + """Instantiate a _SpeesBuilderFromFunction object. + + :param func: function to execute_function_in_sql on the SQL Server. + The spees query is built based on this function. 
+ :param input_data_query: query text for @input_data_1 parameter + :param args: positional arguments to function call in SPEES + :param kwargs: keyword arguments to function call in SPEES + """ + with_inputdf = input_data_query != "" + self._function_text = self._build_wrapper_python_script(func, with_inputdf, *args, **kwargs) + super().__init__(script=self._function_text, + with_results_text=self._WITH_RESULTS_TEXT, + input_data_query=input_data_query) + + # Generates a Python script that encapsulates a user defined function and the arguments to that function. + # This script is "shipped" over the SQL Server machine. + # The function is sent as text. + # The arguments to pass to the function are serialized into their dill hex strings. + # When with_inputdf is True, it specifies that func will take the magic "InputDataSet" as its first arguments. + @staticmethod + def _build_wrapper_python_script(func: Callable, with_inputdf, *args, **kwargs): + dill.settings['recurse'] = True + function_text = SpeesBuilderFromFunction._clean_function_text(inspect.getsource(func)) + args_dill = dill.dumps(kwargs).hex() + pos_args_dill = dill.dumps(args).hex() + function_name = func.__name__ + return """ +{user_function_text} + +import dill +import pandas as pd + +# serialized keyword arguments +args_dill = bytes.fromhex("{args_dill}") +# serialized positional arguments +pos_args_dill = bytes.fromhex("{pos_args_dill}") + +args = dill.loads(args_dill) +pos_args = dill.loads(pos_args_dill) + +# user function name +func = {user_function_name} + +# call user function with serialized arguments +return_val = func{func_arguments} + +return_frame = pd.DataFrame() +# serialize results of user function and put in DataFrame for return through SQL Satellite channel +return_frame["return_val"] = [dill.dumps(return_val).hex()] +OutputDataSet = return_frame +""".format(user_function_text=function_text, + args_dill=args_dill, + pos_args_dill=pos_args_dill, + user_function_name=function_name, + 
+           func_arguments=SpeesBuilderFromFunction._func_arguments(with_inputdf))
+
+    # Call syntax of the user function.
+    # When with_inputdf is true, the user function will always take the "InputDataSet" magic variable as its first
+    # argument.
+    @staticmethod
+    def _func_arguments(with_inputdf: bool):
+        return "(InputDataSet, *pos_args, **args)" if with_inputdf else "(*pos_args, **args)"
+
+    @staticmethod
+    def _clean_function_text(function_text):
+        return textwrap.dedent(function_text)
+
+
+class StoredProcedureBuilder(SQLBuilder):
+
+    def __init__(self, name: str, script: str, input_params: dict = None, output_params: dict = None):
+
+        """StoredProcedureBuilder builds SQL stored procedures from Python scripts.
+
+        :param name: name of the stored procedure
+        :param script: Python script text to base the stored procedure on
+        :param input_params: input parameters type annotation dictionary for the stored procedure
+        :param output_params: output parameters type annotation dictionary from the stored procedure
+        """
+        if input_params is None:
+            input_params = {}
+        if output_params is None:
+            output_params = {}
+        self._script = script
+        self._name = name
+        self._input_params = input_params
+        self._output_params = output_params
+        self._param_declarations = ""
+
+        names_of_input_args = list(self._input_params)
+        names_of_output_args = list(self._output_params)
+
+        self._in_parameter_declarations = self.get_declarations(names_of_input_args, self._input_params)
+        self._out_parameter_declarations = self.get_declarations(names_of_output_args, self._output_params,
+                                                                 outputs=True)
+        self._script_parameter_text = self.script_parameter_text(names_of_input_args, self._input_params,
+                                                                 names_of_output_args, self._output_params)
+
+    @property
+    def base_script(self) -> str:
+        self._param_declarations = self.combine_in_out(
+            self._in_parameter_declarations, self._out_parameter_declarations)
+
+        return """
+CREATE PROCEDURE {name}
+    {parameter_declarations}
+AS
+EXEC sp_execute_external_script
+@language = N'Python', +@script = %s +{script_parameters} +""".format(name=self._name, + parameter_declarations=self._param_declarations, + script_parameters=self._script_parameter_text) + + @property + def params(self): + return self._script + + def script_parameter_text(self, in_names: List[str], in_types: dict, out_names: List[str], out_types: dict) -> str: + if not in_names and not out_names: + return "" + + script_params = "" + self._script = "\nfrom pandas import DataFrame\n" + self._script + + in_data_name = "" + out_data_name = "" + + for name in in_names: + if in_types[name] == DataFrame: + in_data_name = name + in_names.remove(name) + break + + for name in out_names: + if out_types[name] == DataFrame: + out_data_name = name + out_names.remove(name) + break + + if in_data_name != "": + script_params += ",\n" + StoredProcedureBuilderFromFunction.get_input_data_set(in_data_name) + + if out_data_name != "": + script_params += ",\n" + StoredProcedureBuilderFromFunction.get_output_data_set(out_data_name) + + if len(in_names) > 0: + script_params += "," + + in_params_declaration = out_params_declaration = "" + in_params_passing = out_params_passing = "" + if len(in_names) > 0: + in_params_declaration = self.get_declarations(in_names, in_types) + in_params_passing = self.get_params_passing(in_names) + + if len(out_names) > 0: + out_params_declaration = self.get_declarations(out_names, out_types, True) + out_params_passing = self.get_params_passing(out_names, True) + + params_declaration = self.combine_in_out(in_params_declaration, out_params_declaration) + params_passing = self.combine_in_out(in_params_passing, out_params_passing) + + if params_declaration != "": + script_params += "\n@params = N'{params_declarations}',\n {params_passing}".format( + params_declarations=params_declaration, + params_passing=params_passing) + + return script_params + + @staticmethod + def combine_in_out(in_str: str = "", out_str: str = ""): + result = in_str + if result != "" and 
out_str != "": + result += ",\n " + result += out_str + return result + + @staticmethod + def get_input_data_set(name): + return "@input_data_1 = @{name},\n@input_data_1_name = N'{name}'".format(name=name) + + @staticmethod + def get_output_data_set(name): + return "@output_data_1_name = N'{name}'".format(name=name) + + @staticmethod + def get_declarations(names_of_args: List[str], type_annotations: dict, outputs: bool = False): + return ",\n ".join(["@" + name + " {sqltype}{output}".format( + sqltype=StoredProcedureBuilder.to_sql_type(type_annotations.get(name, None)), + output=" OUTPUT" if outputs else "" + ) for name in names_of_args]) + + @staticmethod + def to_sql_type(pytype): + if pytype is None or pytype == str or pytype == DataFrame: + return "nvarchar(MAX)" + elif pytype == int: + return "int" + elif pytype == float: + return "float" + elif pytype == bool: + return "bit" + else: + raise ValueError("Python type: " + str(pytype) + " not supported.") + + @staticmethod + def get_params_passing(names_of_args, outputs: bool = False): + return ",\n ".join(["@" + name + " = " + "@" + name + "{output}".format(output=" OUTPUT" if outputs else "") + for name in names_of_args]) + + +class StoredProcedureBuilderFromFunction(StoredProcedureBuilder): + + """Build query text for stored procedures creation based on Python functions. 
+
+ex:
+
+name: "MyStoredProcedure"
+func:
+def foobar(arg1: str, arg2: str, arg3: str):
+    print(arg1, arg2, arg3)
+
+===========becomes===================
+
+create procedure MyStoredProcedure @arg1 varchar(MAX), @arg2 varchar(MAX), @arg3 varchar(MAX) as
+
+exec sp_execute_external_script
+@language = N'Python',
+@script=N'
+def foobar(arg1, arg2, arg3):
+    print(arg1, arg2, arg3)
+foobar(arg1=arg1, arg2=arg2, arg3=arg3)
+',
+@params = N'@arg1 varchar(MAX), @arg2 varchar(MAX), @arg3 varchar(MAX)',
+@arg1 = @arg1,
+@arg2 = @arg2,
+@arg3 = @arg3
+    """
+
+    def __init__(self, name: str, func: Callable,
+                 input_params: dict = None, output_params: dict = None):
+        """StoredProcedureBuilderFromFunction builds SQL stored procedures from Python functions.
+
+        :param name: name of the stored procedure
+        :param func: function to base the stored procedure on
+        :param input_params: input parameters type annotation dictionary for the stored procedure.
+            You can use function type annotations instead; if both are provided, they must match
+        :param output_params: output parameters type annotation dictionary from the stored procedure
+        """
+        if input_params is None:
+            input_params = {}
+        if output_params is None:
+            output_params = {}
+        self._func = func
+        self._name = name
+        self._output_params = output_params
+
+        # Get function information
+        function_text = textwrap.dedent(inspect.getsource(self._func))
+
+        argspec = inspect.getfullargspec(self._func)
+        names_of_input_args = argspec.args
+        annotations = argspec.annotations
+
+        if argspec.defaults is not None:
+            warnings.warn("Default values are not supported")
+
+        # Figure out input and output parameter dictionaries
+        if input_params != {}:
+            if annotations != {} and annotations != input_params:
+                raise ValueError("Annotations and input_params do not match!")
+            self._input_params = input_params
+        elif annotations != {}:
+            self._input_params = annotations
+        else:
+            # No annotations anywhere; the length check below raises if the
+            # function actually takes arguments.
+            self._input_params = {}
+
+        names_of_output_args = list(self._output_params)
+
+        if len(names_of_input_args) != len(self._input_params):
+            raise ValueError("Number of argument annotations doesn't match the number of arguments!")
+        if set(names_of_input_args) != set(self._input_params.keys()):
+            raise ValueError("Names of arguments do not match the annotation keys!")
+
+        calling_text = self.get_function_calling_text(self._func, names_of_input_args)
+
+        output_data_set = None
+        for name in names_of_output_args:
+            if self._output_params[name] == DataFrame:
+                names_of_output_args.remove(name)
+                output_data_set = name
+                break
+
+        # Creates the base python script to put in the SPEES query.
+        # Arguments to the function are passed by name into the script using the SPEES @params argument.
+        self._script = """
+{function_text}
+{function_call_text}
+{ending}
+""".format(function_text=function_text, function_call_text=calling_text,
+           ending=self.get_ending(self._output_params, output_data_set))
+
+        self._in_parameter_declarations = self.get_declarations(names_of_input_args, self._input_params)
+        self._out_parameter_declarations = self.get_declarations(names_of_output_args, self._output_params,
+                                                                 outputs=True)
+        self._script_parameter_text = self.script_parameter_text(names_of_input_args, self._input_params,
+                                                                 list(self._output_params), self._output_params)
+
+    def script_parameter_text(self, in_names: List[str], in_types: dict, out_names: List[str], out_types: dict) -> str:
+        if not in_names and not out_names:
+            self._script = "\nfrom pandas import DataFrame\n" + self._script
+        return super().script_parameter_text(in_names, in_types, out_names, out_types)
+
+    @staticmethod
+    def get_function_calling_text(func: Callable, names_of_args: List[str]):
+        # For a function named foo with signature def foo(arg1, arg2, arg3)...
+        # kwargs_text is 'arg1=arg1, arg2=arg2, arg3=arg3'
+        kwargs_text = ", ".join("{}={}".format(name, name) for name in names_of_args)
+        # returns 'result = foo(arg1=arg1, arg2=arg2, arg3=arg3)'
+        return "result = " + func.__name__ + "({})".format(kwargs_text)
+
+    # Convert results to an output data frame and output parameters
+    def get_ending(self, output_params: dict, output_data_set: str):
+        res = """
+if type(result) == DataFrame:
+    {result_val}""".format(result_val="{out_df} = result".format(out_df=output_data_set
+                           if output_data_set is not None else "OutputDataSet"))
+
+        if len(output_params) > 0 or output_data_set is not None:
+            res += """
+elif type(result) == dict:
+    {output_params}
+elif result is not None:
+    raise TypeError("Must return a DataFrame or dictionary with output parameters or None")
+""".format(output_params=self.get_output_params(output_params) if len(output_params) > 0 else "pass")
+        return res
+
+    @staticmethod
+    def get_output_params(output_params: dict):
+        return "\n    ".join(['{name} = result["{name}"]'.format(name=name) for name in list(output_params)])
+
+
+class ExecuteStoredProcedureBuilder(SQLBuilder):
+
+    def __init__(self, name: str, **kwargs):
+        self._name = name
+        self._kwargs = kwargs
+
+    # Executes the query: exec sproc @var1 = val1, @var2 = val2...
+    # Does not work with output parameters
+    @property
+    def base_script(self) -> str:
+        parameters = ", ".join(["@{name} = {value}".format(name=name, value=self.format_value(self._kwargs[name]))
+                                for name in self._kwargs])
+        return """exec {} {}""".format(self._name, parameters)
+
+    @staticmethod
+    def format_value(value) -> str:
+        # Check bool before int: bool is a subclass of int in Python, so the
+        # int branch would otherwise swallow True/False and emit invalid SQL.
+        if isinstance(value, str):
+            return "'{}'".format(value)
+        elif isinstance(value, bool):
+            return str(int(value))
+        elif isinstance(value, int) or isinstance(value, float):
+            return str(value)
+        else:
+            raise ValueError("Parameter type {} not supported.".format(str(type(value))))
+
+
+class DropStoredProcedureBuilder(SQLBuilder):
+
+    def __init__(self, name: str):
+        self._name = name
+
+    @property
+    def base_script(self) -> str:
+        return """
+drop procedure {}
+""".format(self._name)
diff --git a/Python/sqlmlutils/sqlpythonexecutor.py b/Python/sqlmlutils/sqlpythonexecutor.py
new file mode 100644
index 0000000..541995a
--- /dev/null
+++ b/Python/sqlmlutils/sqlpythonexecutor.py
@@ -0,0 +1,209 @@
+# Copyright(c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+from typing import Callable
+import dill
+from pandas import DataFrame
+
+from .connectioninfo import ConnectionInfo
+from .sqlqueryexecutor import execute_query, execute_raw_query
+from .sqlbuilder import SpeesBuilder, SpeesBuilderFromFunction, StoredProcedureBuilder, \
+    ExecuteStoredProcedureBuilder, DropStoredProcedureBuilder
+from .sqlbuilder import StoredProcedureBuilderFromFunction, RETURN_COLUMN_NAME
+
+
+class SQLPythonExecutor:
+
+    def __init__(self, connection_info: ConnectionInfo):
+        self._connection_info = connection_info
+
+    def execute_function_in_sql(self,
+                                func: Callable, *args,
+                                input_data_query: str = "",
+                                **kwargs):
+        """Execute a function in SQL Server.
+
+        :param func: function to execute in SQL. NOTE: This function is shipped to SQL as text.
+            Functions should be self-contained and import statements should be inline.
+ :param args: positional args to pass to function to execute_function_in_sql. + :param input_data_query: sql query to fill the first argument of the function. The argument gets the result of + the query as a pandas DataFrame (uses the @input_data_1 parameter in sp_execute_external_script) + :param kwargs: keyword arguments to pass to function to execute_function_in_sql. + :return: value returned by func + + >>> from sqlmlutils import ConnectionInfo, SQLPythonExecutor + >>> + >>> def foo(val1, val2): + >>> import math + >>> print(val1) + >>> return [math.cos(val2), math.cos(val2)] + >>> + >>> sqlpy = SQLPythonExecutor(ConnectionInfo("localhost", database="AirlineTestDB")) + >>> ret = sqlpy.execute_function_in_sql(foo, val1="blah", val2=5) + blah + >>> print(ret) + [0.28366218546322625, 0.28366218546322625] + """ + rows = execute_query(SpeesBuilderFromFunction(func, input_data_query, *args, **kwargs), self._connection_info) + return self._get_results(rows) + + def execute_script_in_sql(self, + path_to_script: str, + input_data_query: str = ""): + """Execute a script in SQL Server. + + :param path_to_script: file path to Python script to execute. + :param input_data_query: sql query to fill InputDataSet global variable with. + (@input_data_1 parameter in sp_execute_external_script) + :return: None + + """ + try: + with open(path_to_script, 'r') as script_file: + content = script_file.read() + print("File does exist, using " + path_to_script) + except FileNotFoundError: + raise FileNotFoundError("File does not exist!") + execute_query(SpeesBuilder(content, input_data_query=input_data_query), connection=self._connection_info) + + def execute_sql_query(self, + sql_query: str): + """Execute a sql query in SQL Server. 
+ + :param sql_query: the sql query to execute in the server + :return: table returned by the sql_query + """ + rows = execute_raw_query(conn=self._connection_info, query=sql_query) + df = DataFrame(rows) + + # _mssql's execute_query() returns duplicate keys for indexing, we remove them because they are extraneous + for i in range(len(df.columns)): + try: + del df[i] + except KeyError: + pass + + return df + + def create_sproc_from_function(self, name: str, func: Callable, + input_params: dict = None, output_params: dict = None): + """Create a SQL Server stored procedure based on a Python function. + NOTE: Type annotations are needed either in the function definition or in the input_params dictionary + WARNING: Output parameters can be used when creating the stored procedure, but Stored Procedures with + output parameters other than a single DataFrame cannot be executed with sqlmlutils + + :param name: name of stored procedure. + :param func: function used to define stored procedure. parameters to the function are used to define parameters + to the stored procedure. type annotations of the parameters are used to infer SQL types of parameters to the + stored procedure. currently supported type annotations are "str", "int", "float", and "DataFrame". + :param input_params: optional dictionary of type annotations for each argument to func; + if func has type annotations this is not necessary. 
+            If both are provided, they must match
+        :param output_params: optional dictionary of type annotations for each output parameter
+        :return: True if creation succeeded
+
+        >>> from sqlmlutils import ConnectionInfo, SQLPythonExecutor
+        >>>
+        >>> def foo(val1: int, val2: str):
+        >>>     from pandas import DataFrame
+        >>>     print(val2)
+        >>>     df = DataFrame()
+        >>>     df["col1"] = [val1, val1, val1]
+        >>>     return df
+        >>>
+        >>> sqlpy = SQLPythonExecutor(ConnectionInfo("localhost", database="AutoRegressTestDB"))
+        >>> sqlpy.create_sproc_from_function("MyStoredProcedure", foo)
+        >>>
+        >>> # You can execute the procedure in the usual way from sql: exec MyStoredProcedure 5, 'bar'
+        >>> # You can also call the stored procedure from Python
+        >>> ret = sqlpy.execute_sproc(name="MyStoredProcedure", val1=5, val2="bar")
+        >>> sqlpy.drop_sproc(name="MyStoredProcedure")
+
+        """
+        if input_params is None:
+            input_params = {}
+        if output_params is None:
+            output_params = {}
+        # Save the stored procedure in the database
+        execute_query(StoredProcedureBuilderFromFunction(name, func,
+                                                         input_params, output_params), self._connection_info)
+        return True
+
+    def create_sproc_from_script(self, name: str, path_to_script: str,
+                                 input_params: dict = None, output_params: dict = None):
+        """Create a SQL Server stored procedure based on a Python script.
+
+        :param name: name of stored procedure.
+        :param path_to_script: file path to Python script to create a sproc from.
+        :param input_params: optional dictionary of type annotations for inputs in the script
+        :param output_params: optional dictionary of type annotations for each output variable
+        :return: True if creation succeeded
+
+        >>> from sqlmlutils import ConnectionInfo, SQLPythonExecutor
+        >>>
+        >>>
+        >>> sqlpy = SQLPythonExecutor(ConnectionInfo("localhost", database="AutoRegressTestDB"))
+        >>> sqlpy.create_sproc_from_script(name="script_sproc", path_to_script="path/to/script")
+        >>>
+        >>> # This will execute the script in SQL; with no inputs or outputs it will just run and return nothing
+        >>> sqlpy.execute_sproc(name="script_sproc")
+        >>> sqlpy.drop_sproc(name="script_sproc")
+
+        """
+        if input_params is None:
+            input_params = {}
+        if output_params is None:
+            output_params = {}
+        # Save the stored procedure in the database
+        try:
+            with open(path_to_script, 'r') as script_file:
+                content = script_file.read()
+            print("Using file " + path_to_script)
+        except FileNotFoundError:
+            raise FileNotFoundError("File does not exist!")
+
+        execute_query(StoredProcedureBuilder(name, content,
+                                             input_params, output_params), self._connection_info)
+        return True
+
+    def check_sproc(self, name: str) -> bool:
+        """Check to see if a SQL Server stored procedure exists in the database.
+
+        >>> from sqlmlutils import ConnectionInfo, SQLPythonExecutor
+        >>>
+        >>> sqlpy = SQLPythonExecutor(ConnectionInfo("localhost", database="AutoRegressTestDB"))
+        >>> if sqlpy.check_sproc("MyStoredProcedure"):
+        >>>     print("MyStoredProcedure exists")
+        >>> else:
+        >>>     print("MyStoredProcedure does not exist")
+
+        :param name: name of stored procedure.
+        :return: boolean whether the Stored Procedure exists in the database
+        """
+        check_query = "SELECT OBJECT_ID (%s, N'P')"
+        rows = execute_raw_query(conn=self._connection_info, query=check_query, params=name)
+        return rows[0][0] is not None
+
+    def execute_sproc(self, name: str, **kwargs) -> DataFrame:
+        """Call a stored procedure on a SQL Server database.
+ WARNING: Output parameters can be used when creating the stored procedure, but stored procedures with
+ output parameters other than a single DataFrame cannot be executed with sqlmlutils.
+
+ :param name: name of stored procedure.
+ :param kwargs: keyword arguments to pass to the stored procedure
+ :return: DataFrame representing the output data set of the stored procedure (or empty)
+ """
+ return DataFrame(execute_query(ExecuteStoredProcedureBuilder(name, **kwargs), self._connection_info))
+
+ def drop_sproc(self, name: str):
+ """Drop a SQL Server stored procedure if it exists.
+
+ :param name: name of stored procedure.
+ :return: None
+ """
+ if self.check_sproc(name):
+ execute_query(DropStoredProcedureBuilder(name), self._connection_info)
+
+ @staticmethod
+ def _get_results(rows):
+ hexstring = rows[0][RETURN_COLUMN_NAME]
+ return dill.loads(bytes.fromhex(hexstring))
diff --git a/Python/sqlmlutils/sqlqueryexecutor.py b/Python/sqlmlutils/sqlqueryexecutor.py
new file mode 100644
index 0000000..d17dfc3
--- /dev/null
+++ b/Python/sqlmlutils/sqlqueryexecutor.py
@@ -0,0 +1,90 @@
+# Copyright(c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+import _mssql
+from .connectioninfo import ConnectionInfo
+from .sqlbuilder import SQLBuilder
+
+"""This module is used to actually execute SQL queries. It uses the pymssql module under the hood.
+
+It is mostly set up to work with SQLBuilder objects as defined in sqlbuilder.
+"""
+
+
+# This function is best used to execute a one-off query
+# (the SQL connection is closed after the query completes).
+# If you need to keep the SQL connection open in between queries, you can use the SQLQueryExecutor class below.
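The comment above distinguishes one-off execution (connection closed as soon as the query completes) from holding a connection open across queries. The context-manager mechanics that make the one-off path safe can be sketched without a database; `FakeConnection` and `Executor` below are hypothetical stand-ins, not part of sqlmlutils:

```python
class FakeConnection:
    """Hypothetical stand-in for an _mssql connection."""
    def __init__(self):
        self.closed = False
        self.queries = []

    def execute_query(self, query):
        self.queries.append(query)

    def close(self):
        self.closed = True


class Executor:
    """Mimics the shape of SQLQueryExecutor: connect on __enter__, close on __exit__."""
    def __enter__(self):
        self.conn = FakeConnection()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Runs even if the body raised, so the connection is never leaked
        self.conn.close()

    def execute(self, query):
        self.conn.execute_query(query)
        return list(self.conn.queries)


# One-off style, as in execute_query()/execute_raw_query()
with Executor() as executor:
    rows = executor.execute("SELECT 1")

print(executor.conn.closed)  # True: closed as soon as the block exits
```

Keeping the executor inside a `with` block is what guarantees the close even when a query raises, which is why the module funnels one-off queries through it.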
+def execute_query(builder, connection: ConnectionInfo):
+ with SQLQueryExecutor(connection=connection) as executor:
+ return executor.execute(builder)
+
+
+def execute_raw_query(conn: ConnectionInfo, query, params=()):
+ with SQLQueryExecutor(connection=conn) as executor:
+ return executor.execute_query(query, params)
+
+
+def _sql_msg_handler(msgstate, severity, srvname, procname, line, msgtext):
+ print(msgtext.decode())
+
+
+class SQLQueryExecutor:
+
+ """SQLQueryExecutor objects keep a SQL connection open in order to execute one or more queries.
+
+ This class implements the basic context manager paradigm.
+ """
+
+ def __init__(self, connection: ConnectionInfo):
+ self._connection = connection
+
+ def execute(self, builder: SQLBuilder):
+ try:
+ self._mssqlconn.set_msghandler(_sql_msg_handler)
+ self._mssqlconn.execute_query(builder.base_script, builder.params)
+ return [row for row in self._mssqlconn]
+ except Exception as e:
+ raise RuntimeError(str.format("Error in SQL Execution: {error}", error=str(e)))
+
+ def execute_query(self, query, params):
+ self._mssqlconn.execute_query(query, params)
+ return [row for row in self._mssqlconn]
+
+ def __enter__(self):
+ self._mssqlconn = _mssql.connect(server=self._connection.server,
+ user=self._connection.uid,
+ password=self._connection.pwd,
+ database=self._connection.database)
+ self._mssqlconn.set_msghandler(_sql_msg_handler)
+ return self
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ self._mssqlconn.close()
+
+
+class SQLTransaction:
+
+ def __init__(self, executor: SQLQueryExecutor, name):
+ self._executor = executor
+ self._name = name
+
+ def begin(self):
+ query = """
+declare @transactionname varchar(MAX) = %s;
+begin tran @transactionname;
+ """
+ self._executor.execute_query(query, self._name)
+
+ def rollback(self):
+ query = """
+declare @transactionname varchar(MAX) = %s;
+rollback tran @transactionname;
+ """
+ self._executor.execute_query(query,
self._name)
+
+ def commit(self):
+ query = """
+declare @transactionname varchar(MAX) = %s;
+commit tran @transactionname;
+ """
+ self._executor.execute_query(query, self._name)
diff --git a/Python/sqlmlutils/storedprocedure.py b/Python/sqlmlutils/storedprocedure.py
new file mode 100644
index 0000000..041e0cd
--- /dev/null
+++ b/Python/sqlmlutils/storedprocedure.py
@@ -0,0 +1,36 @@
+# Copyright(c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+from pandas import DataFrame
+
+from .connectioninfo import ConnectionInfo
+from .sqlqueryexecutor import execute_query
+from .sqlbuilder import ExecuteStoredProcedureBuilder, DropStoredProcedureBuilder
+
+
+class StoredProcedure:
+ """Represents a SQL Server stored procedure."""
+
+ def __init__(self, name: str, connection: ConnectionInfo):
+ """Instantiates a StoredProcedure. Not meant to be called directly; get handles to stored
+ procedures using get_sproc.
+
+ :param name: name of stored procedure.
+ """
+ self._name = name
+ self._connection = connection
+
+ def call(self, **kwargs) -> DataFrame:
+ """Call a stored procedure on a SQL Server database.
+
+ :param kwargs: keyword arguments to pass to the stored procedure
+ :return: DataFrame representing the output data set of the stored procedure (or empty)
+ """
+ return DataFrame(execute_query(ExecuteStoredProcedureBuilder(self._name, **kwargs), self._connection))
+
+ def drop(self):
+ """Drop a SQL Server stored procedure.
+
+ :return: None
+ """
+ execute_query(DropStoredProcedureBuilder(self._name), self._connection)
diff --git a/Python/tests/execute_function_test.py b/Python/tests/execute_function_test.py
new file mode 100644
index 0000000..caa631a
--- /dev/null
+++ b/Python/tests/execute_function_test.py
@@ -0,0 +1,168 @@
+# Copyright(c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+ +import pytest +from contextlib import redirect_stdout, redirect_stderr +import io +import os + +from sqlmlutils import SQLPythonExecutor +from sqlmlutils import ConnectionInfo +from pandas import DataFrame + +current_dir = os.path.dirname(__file__) +script_dir = os.path.join(current_dir, "scripts") +connection = ConnectionInfo(server="localhost", database="AirlineTestDB") +sqlpy = SQLPythonExecutor(connection) + + +def test_with_named_args(): + def func_with_args(arg1, arg2): + print(arg1) + return arg2 + + output = io.StringIO() + with redirect_stderr(output), redirect_stdout(output): + res = sqlpy.execute_function_in_sql(func_with_args, arg1="str1", arg2="str2") + + assert "str1" in output.getvalue() + assert res == "str2" + + +def test_with_order_args(): + def func_with_order_args(arg1: int, arg2: float): + return arg1 / arg2 + + res = sqlpy.execute_function_in_sql(func_with_order_args, 2, 3.0) + assert res == 2 / 3.0 + res = sqlpy.execute_function_in_sql(func_with_order_args, 3.0, 2) + assert res == 3 / 2.0 + + +def test_return(): + def func_with_return(): + return "returned!" 
+
+ res = sqlpy.execute_function_in_sql(func_with_return)
+ assert res == func_with_return()
+
+
+@pytest.mark.skip(reason="Do we capture warnings?")
+def test_warning():
+ def func_with_warning():
+ import warnings
+ warnings.warn("WARNING!")
+
+ res = sqlpy.execute_function_in_sql(func_with_warning)
+ assert res is None
+
+
+def test_with_internal_func():
+ def func_with_internal_func():
+ def func2(arg1, arg2):
+ return arg1 + arg2
+
+ return func2("Suc", "cess")
+
+ res = sqlpy.execute_function_in_sql(func_with_internal_func)
+ assert res == "Success"
+
+
+@pytest.mark.skip(reason="Cannot currently return a function")
+def test_return_func():
+ def func2(arg1, arg2):
+ return arg1 + arg2
+
+ def func_returns_func():
+ def func2(arg1, arg2):
+ return arg1 + arg2
+
+ return func2
+
+ res = sqlpy.execute_function_in_sql(func_returns_func)
+ assert res == func2
+
+
+@pytest.mark.skip(reason="Cannot currently return a function outside of environment")
+def test_return_func_outside_env():
+ def func2(arg1, arg2):
+ return arg1 + arg2
+
+ def func_returns_func():
+ return func2
+
+ res = sqlpy.execute_function_in_sql(func_returns_func)
+ assert res == func2
+
+
+def test_with_no_args():
+ def func_with_no_args():
+ return
+
+ res = sqlpy.execute_function_in_sql(func_with_no_args)
+
+ assert res is None
+
+
+def test_with_data_frame():
+ def func_return_df(in_df):
+ return in_df
+
+ res = sqlpy.execute_function_in_sql(func_return_df,
+ input_data_query="SELECT TOP 10 * FROM airline5000")
+
+ assert type(res) == DataFrame
+ assert res.shape == (10, 30)
+
+
+def test_with_variables():
+ def func_with_variables(s):
+ print(s)
+
+ output = io.StringIO()
+ with redirect_stderr(output), redirect_stdout(output):
+ sqlpy.execute_function_in_sql(func_with_variables, s="Hello")
+
+ assert "Hello" in output.getvalue()
+
+ output = io.StringIO()
+ with redirect_stderr(output), redirect_stdout(output):
+ var_s = "World"
+ sqlpy.execute_function_in_sql(func_with_variables, s=var_s)
+
+ assert
"World" in output.getvalue() + + +def test_execute_query(): + res = sqlpy.execute_sql_query("SELECT TOP 10 * FROM airline5000") + + assert type(res) == DataFrame + assert res.shape == (10, 30) + + +def test_execute_script(): + path = os.path.join(script_dir, "test_script.py") + + output = io.StringIO() + with redirect_stderr(output), redirect_stdout(output): + res = sqlpy.execute_script_in_sql(path_to_script=path, + input_data_query="SELECT TOP 10 * FROM airline5000") + + assert "HelloWorld" in output.getvalue() + assert res is None + + with pytest.raises(FileNotFoundError): + sqlpy.execute_script_in_sql(path_to_script="NonexistentScriptPath", + input_data_query="SELECT TOP 10 * FROM airline5000") + + +def test_stderr(): + def print_to_stderr(): + import sys + sys.stderr.write("Error!") + + output = io.StringIO() + with redirect_stderr(output), redirect_stdout(output): + sqlpy.execute_function_in_sql(print_to_stderr) + + assert "Error!" in output.getvalue() diff --git a/Python/tests/package_helper_functions.py b/Python/tests/package_helper_functions.py new file mode 100644 index 0000000..44f16cc --- /dev/null +++ b/Python/tests/package_helper_functions.py @@ -0,0 +1,13 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +from sqlmlutils.sqlqueryexecutor import execute_raw_query + + +def _get_sql_package_table(connection): + query = "select * from sys.external_libraries" + return execute_raw_query(connection, query) + + +def _get_package_names_list(connection): + return {dic['name']: dic['scope'] for dic in _get_sql_package_table(connection)} diff --git a/Python/tests/package_management_file_test.py b/Python/tests/package_management_file_test.py new file mode 100644 index 0000000..eac22da --- /dev/null +++ b/Python/tests/package_management_file_test.py @@ -0,0 +1,266 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
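The tests above rely on a single capture idiom throughout: route both stdout and stderr into an `io.StringIO` and assert on its contents. A self-contained version of that idiom, with a made-up `noisy` function in place of the SQL round-trip:

```python
import io
import sys
from contextlib import redirect_stderr, redirect_stdout


def noisy():
    print("to stdout")
    sys.stderr.write("to stderr")


output = io.StringIO()
# Both streams funnel into the same buffer, mirroring the test helpers above
with redirect_stderr(output), redirect_stdout(output):
    noisy()

captured = output.getvalue()
assert "to stdout" in captured
assert "to stderr" in captured
```

Note that `redirect_stdout`/`redirect_stderr` only swap the Python-level `sys.stdout`/`sys.stderr`, which is sufficient here because the executor prints server messages via a Python message handler.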
+ +import io +import os +import subprocess +import tempfile +from contextlib import redirect_stdout + +import pytest + +import sqlmlutils +from sqlmlutils import SQLPackageManager, SQLPythonExecutor +from package_helper_functions import _get_sql_package_table, _get_package_names_list +from sqlmlutils.packagemanagement.scope import Scope +from sqlmlutils.packagemanagement.pipdownloader import PipDownloader + +connection = sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB") +path_to_packages = os.path.join((os.path.dirname(os.path.realpath(__file__))), "scripts", "test_packages") +_SUCCESS_TOKEN = "SUCCESS" + +pyexecutor = SQLPythonExecutor(connection) +pkgmanager = SQLPackageManager(connection) + +originals = _get_sql_package_table(connection) + +def check_package(package_name: str, exists: bool, class_to_check: str = ""): + if exists: + themodule = __import__(package_name) + assert themodule is not None + assert getattr(themodule, class_to_check) is not None + else: + import pytest + with pytest.raises(Exception): + __import__(package_name) + + +def _execute_sql(script: str) -> bool: + tmpfile = tempfile.NamedTemporaryFile(delete=False) + tmpfile.write(script.encode()) + tmpfile.close() + command = ["sqlcmd", "-d", "AirlineTestDB", "-i", tmpfile.name] + try: + output = subprocess.check_output(command, stderr=subprocess.STDOUT, shell=True).decode() + return _SUCCESS_TOKEN in output + finally: + os.remove(tmpfile.name) + + +def _drop(package_name: str, ddl_name: str): + pkgmanager.uninstall(package_name) + pyexecutor.execute_function_in_sql(check_package, package_name=package_name, exists=False) + + +def _create(module_name: str, package_file: str, class_to_check: str, drop: bool = True): + pyexecutor.execute_function_in_sql(check_package, package_name=module_name, exists=False) + pkgmanager.install(package_file) + pyexecutor.execute_function_in_sql(check_package, package_name=module_name, exists=True, class_to_check=class_to_check) + if drop: + 
_drop(package_name=module_name, ddl_name=module_name)
+
+
+def _remove_all_new_packages(manager):
+ libs = {dic['external_library_id']: (dic['name'], dic['scope']) for dic in _get_sql_package_table(connection)}
+ original_libs = {dic['external_library_id']: (dic['name'], dic['scope']) for dic in originals}
+
+ for lib in libs:
+ pkg, sc = libs[lib]
+ if lib not in original_libs:
+ print("Uninstalling " + str(lib))
+ if sc:
+ manager.uninstall(pkg, scope=Scope.private_scope())
+ else:
+ manager.uninstall(pkg, scope=Scope.public_scope())
+ else:
+ if sc != original_libs[lib][1]:
+ if sc:
+ manager.uninstall(pkg, scope=Scope.private_scope())
+ else:
+ manager.uninstall(pkg, scope=Scope.public_scope())
+
+
+packages = ["absl-py==0.1.13", "astor==0.6.2", "bleach==1.5.0", "cryptography==2.2.2",
+ "html5lib==1.0.1", "Markdown==2.6.11", "numpy==1.14.3", "termcolor==1.1.0", "webencodings==0.5.1"]
+
+for package in packages:
+ pipdownloader = PipDownloader(connection, path_to_packages, package)
+ pipdownloader.download_single()
+
+def test_install_basic_zip_package():
+ package = os.path.join(path_to_packages, "testpackageA-0.0.1.zip")
+ module_name = "testpackageA"
+
+ _remove_all_new_packages(pkgmanager)
+
+ _create(module_name=module_name, package_file=package, class_to_check="ClassA")
+
+
+def test_install_basic_zip_package_different_name():
+ package = os.path.join(path_to_packages, "testpackageA-0.0.1.zip")
+ module_name = "testpackageA"
+
+ _remove_all_new_packages(pkgmanager)
+ _create(module_name=module_name, package_file=package, class_to_check="ClassA")
+
+
+def test_install_whl_files():
+ packages = ["webencodings-0.5.1-py2.py3-none-any.whl", "html5lib-1.0.1-py2.py3-none-any.whl",
+ "astor-0.6.2-py2.py3-none-any.whl"]
+ module_names = ["webencodings", "html5lib", "astor"]
+ classes_to_check = ["LABELS", "parse", "code_gen"]
+
+ _remove_all_new_packages(pkgmanager)
+
+ for package, module, class_to_check in zip(packages, module_names, classes_to_check):
full_package = os.path.join(path_to_packages, package) + _create(module_name=module, package_file=full_package, class_to_check=class_to_check, drop=False) + + for name in module_names: + _drop(package_name=name, ddl_name=name) + + +def test_install_targz_files(): + packages = ["termcolor-1.1.0.tar.gz"] + module_names = ["termcolor"] + ddl_names = ["termcolor"] + classes_to_check = ["colored"] + + _remove_all_new_packages(pkgmanager) + + for package, module, ddl_name, class_to_check in zip(packages, module_names, ddl_names, classes_to_check): + full_package = os.path.join(path_to_packages, package) + _create(module_name=module, package_file=full_package, class_to_check=class_to_check) + + +def test_install_bad_package_badzipfile(): + + _remove_all_new_packages(pkgmanager) + + with tempfile.TemporaryDirectory() as temporary_directory: + badpackagefile = os.path.join(temporary_directory, "badpackageA-0.0.1.zip") + with open(badpackagefile, "w") as f: + f.write("asdasdasdascsacsadsadas") + with pytest.raises(Exception): + pkgmanager.install(badpackagefile) + + assert "badpackageA" not in _get_package_names_list(connection) + + query = """ +declare @val int; +set @val = (select count(*) from sys.external_libraries where name='badpackageA') +if @val = 0 + print('{}') +""".format(_SUCCESS_TOKEN) + + assert _execute_sql(query) + + +def test_package_already_exists_on_sql_table(): + + _remove_all_new_packages(pkgmanager) + + package = os.path.join(path_to_packages, "testpackageA-0.0.1.zip") + pkgmanager.install(package) + + # Without upgrade + output = io.StringIO() + with redirect_stdout(output): + pkgmanager.install(package, upgrade=False) + assert "exists on server. Set upgrade to True to force upgrade." 
in output.getvalue() + + # With upgrade + package = os.path.join(path_to_packages, "testpackageA-0.0.2.zip") + pkgmanager.install(package, upgrade=True) + + def check_version(): + import testpackageA + return testpackageA.__version__ + + version = pyexecutor.execute_function_in_sql(check_version) + assert version == "0.0.2" + + pkgmanager.uninstall("testpackageA") + + +def test_upgrade_parameter(): + + _remove_all_new_packages(pkgmanager) + + # Get sql packages + originalsqlpkgs = _get_sql_package_table(connection) + + pkg = os.path.join(path_to_packages, "cryptography-2.2.2-cp35-cp35m-win_amd64.whl") + + output = io.StringIO() + with redirect_stdout(output): + pkgmanager.install(pkg, upgrade=False) + assert "exists on server. Set upgrade to True to force upgrade." in output.getvalue() + + # Assert no additional packages were installed + + sqlpkgs = _get_sql_package_table(connection) + assert len(sqlpkgs) == len(originalsqlpkgs) + + ################# + + def check_version(): + import cryptography as cp + return cp.__version__ + + oldversion = pyexecutor.execute_function_in_sql(check_version) + + pkgmanager.install(pkg, upgrade=True) + + sqlpkgs = _get_sql_package_table(connection) + assert len(sqlpkgs) == len(originalsqlpkgs) + 2 + + version = pyexecutor.execute_function_in_sql(check_version) + assert version == "2.2.2" + assert version > oldversion + + pkgmanager.uninstall("cryptography") + pkgmanager.uninstall("asn1crypto") + + sqlpkgs = _get_sql_package_table(connection) + assert len(sqlpkgs) == len(originalsqlpkgs) + + +# TODO: more tests for drop external library + +def test_scope(): + + _remove_all_new_packages(pkgmanager) + + package = os.path.join(path_to_packages, "testpackageA-0.0.1.zip") + + def get_location(): + import testpackageA + return testpackageA.__file__ + + _revotesterconnection = sqlmlutils.ConnectionInfo(server="localhost", + database="AirlineTestDB", + uid="Tester", + pwd="FakeT3sterPwd!") + revopkgmanager = 
SQLPackageManager(_revotesterconnection) + revoexecutor = SQLPythonExecutor(_revotesterconnection) + + revopkgmanager.install(package, scope=Scope.private_scope()) + private_location = revoexecutor.execute_function_in_sql(get_location) + + pkg_name = "testpackageA" + + pyexecutor.execute_function_in_sql(check_package, package_name=pkg_name, exists=False) + + revopkgmanager.uninstall(pkg_name, scope=Scope.private_scope()) + + revopkgmanager.install(package, scope=Scope.public_scope()) + public_location = revoexecutor.execute_function_in_sql(get_location) + + assert private_location != public_location + pyexecutor.execute_function_in_sql(check_package, package_name=pkg_name, exists=True, class_to_check='ClassA') + + revopkgmanager.uninstall(pkg_name, scope=Scope.public_scope()) + + revoexecutor.execute_function_in_sql(check_package, package_name=pkg_name, exists=False) + pyexecutor.execute_function_in_sql(check_package, package_name=pkg_name, exists=False) diff --git a/Python/tests/package_management_pypi_test.py b/Python/tests/package_management_pypi_test.py new file mode 100644 index 0000000..a9056b1 --- /dev/null +++ b/Python/tests/package_management_pypi_test.py @@ -0,0 +1,232 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
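`_remove_all_new_packages` in the tests above restores the server by diffing the current `sys.external_libraries` rows against the snapshot taken at import time. The core bookkeeping reduces to a dictionary diff; the sketch below uses made-up ids and names, with a 0/1 scope flag matching the truthiness check in the helper:

```python
# id -> (name, scope); scope 1 = private, 0 = public
original_libs = {1: ("numpy", 0), 2: ("pandas", 0)}
current_libs = {1: ("numpy", 0), 2: ("pandas", 0), 3: ("testpackageA", 1)}

to_uninstall = []
for lib_id, (name, scope) in current_libs.items():
    if lib_id not in original_libs:
        # Installed during the test run: remove it in the scope it was installed with
        to_uninstall.append((name, "private" if scope else "public"))

print(to_uninstall)  # [('testpackageA', 'private')]
```

Keying on `external_library_id` rather than package name is what lets the helper also detect a package whose scope changed between the snapshot and the current state.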
+ +import sqlmlutils +import os +import pytest +from sqlmlutils import SQLPythonExecutor, SQLPackageManager +from sqlmlutils.packagemanagement.scope import Scope +from package_helper_functions import _get_sql_package_table, _get_package_names_list +import io +from contextlib import redirect_stdout + + +def _drop_all_ddl_packages(conn): + pkgs = _get_sql_package_table(conn) + for pkg in pkgs: + try: + SQLPackageManager(conn)._drop_sql_package(pkg['name'], scope=Scope.private_scope()) + except Exception: + pass + + +server = os.environ.get("SQLPY_TEST_SERVER", "localhost") +database = os.environ.get("SQLPY_TEST_DB", "AirlineTestDB") +uid = os.environ.get("SQLPY_TEST_UID", "") +pwd = os.environ.get("SQLPY_TEST_PWD", "") +connection = sqlmlutils.ConnectionInfo(server=server, database=database, uid=uid, pwd=pwd) +pyexecutor = SQLPythonExecutor(connection) +pkgmanager = SQLPackageManager(connection) +_drop_all_ddl_packages(connection) + + +def _package_exists(module_name: str): + mod = __import__(module_name) + return mod is not None + + +def _package_no_exist(module_name: str): + import pytest + with pytest.raises(Exception): + __import__(module_name) + return True + + +def test_install_tensorflow_and_keras(): + def use_tensorflow(): + import tensorflow as tf + node1 = tf.constant(3.0, tf.float32) + return str(node1.dtype) + + def use_keras(): + import keras + + pkgmanager.install("tensorflow") + val = pyexecutor.execute_function_in_sql(use_tensorflow) + assert 'float32' in val + + pkgmanager.install("keras") + pyexecutor.execute_function_in_sql(use_keras) + pkgmanager.uninstall("keras") + val = pyexecutor.execute_function_in_sql(_package_no_exist, "keras") + assert val + + pkgmanager.uninstall("tensorflow") + val = pyexecutor.execute_function_in_sql(_package_no_exist, "tensorflow") + assert val + + _drop_all_ddl_packages(connection) + + +def test_install_many_packages(): + packages = ["multiprocessing_on_dill", "simplejson"] + + for package in packages: + 
pkgmanager.install(package, upgrade=True) + val = pyexecutor.execute_function_in_sql(_package_exists, module_name=package) + assert val + + pkgmanager.uninstall(package) + val = pyexecutor.execute_function_in_sql(_package_no_exist, module_name=package) + assert val + + _drop_all_ddl_packages(connection) + + +def test_install_version(): + package = "simplejson" + v = "3.0.3" + + def _package_version_exists(module_name: str, version: str): + mod = __import__(module_name) + return mod.__version__ == version + + pkgmanager.install(package, version=v) + val = pyexecutor.execute_function_in_sql(_package_version_exists, module_name=package, version=v) + assert val + + pkgmanager.uninstall(package) + val = pyexecutor.execute_function_in_sql(_package_no_exist, module_name=package) + assert val + + _drop_all_ddl_packages(connection) + + +def test_dependency_resolution(): + package = "multiprocessing_on_dill" + + pkgmanager.install(package, upgrade=True) + val = pyexecutor.execute_function_in_sql(_package_exists, module_name=package) + assert val + + pkgs = _get_package_names_list(connection) + + assert package in pkgs + assert "pyreadline" in pkgs + + pkgmanager.uninstall(package) + val = pyexecutor.execute_function_in_sql(_package_no_exist, module_name=package) + assert val + + _drop_all_ddl_packages(connection) + + +def test_upgrade_parameter(): + + pkg = "cryptography" + + # Get sql packages + originalsqlpkgs = _get_sql_package_table(connection) + + output = io.StringIO() + with redirect_stdout(output): + pkgmanager.install(pkg, upgrade=False) + assert "exists on server. Set upgrade to True to force upgrade." 
in output.getvalue() + + # Assert no additional packages were installed + + sqlpkgs = _get_sql_package_table(connection) + assert len(sqlpkgs) == len(originalsqlpkgs) + + ################# + + def check_version(): + import cryptography as cp + return cp.__version__ + + oldversion = pyexecutor.execute_function_in_sql(check_version) + + pkgmanager.install(pkg, upgrade=True) + + afterinstall = _get_sql_package_table(connection) + assert len(afterinstall) > len(originalsqlpkgs) + + version = pyexecutor.execute_function_in_sql(check_version) + assert version > oldversion + + pkgmanager.uninstall("cryptography") + + sqlpkgs = _get_sql_package_table(connection) + assert len(sqlpkgs) == len(afterinstall) - 1 + + _drop_all_ddl_packages(connection) + + +def test_install_abslpy(): + pkgmanager.install("absl-py") + + def useit(): + import absl + return absl.__file__ + + pyexecutor.execute_function_in_sql(useit) + + pkgmanager.uninstall("absl-py") + + def dontuseit(): + import pytest + with pytest.raises(Exception): + import absl + + pyexecutor.execute_function_in_sql(dontuseit) + + _drop_all_ddl_packages(connection) + + +@pytest.mark.skip(reason="Theano depends on a conda package libpython? 
lazylinker issue") +def test_install_theano(): + pkgmanager.install("Theano") + + def useit(): + import theano.tensor as T + return str(T) + + pyexecutor.execute_function_in_sql(useit) + + pkgmanager.uninstall("Theano") + + pkgmanager.install("theano") + pyexecutor.execute_function_in_sql(useit) + pkgmanager.uninstall("theano") + + _drop_all_ddl_packages(connection) + + +def test_already_installed_popular_ml_packages(): + installedpackages = ["numpy", "scipy", "pandas", "matplotlib", "seaborn", "bokeh", "nltk", "statsmodels"] + + sqlpkgs = _get_sql_package_table(connection) + for package in installedpackages: + pkgmanager.install(package) + newsqlpkgs = _get_sql_package_table(connection) + assert len(sqlpkgs) == len(newsqlpkgs) + + +def test_installing_popular_ml_packages(): + newpackages = ["plotly", "cntk", "gensim"] + + def checkit(pkgname): + val = __import__(pkgname) + return str(val) + + for package in newpackages: + pkgmanager.install(package) + pyexecutor.execute_function_in_sql(checkit, pkgname=package) + + _drop_all_ddl_packages(connection) + + +# TODO: find a bad pypi package to test this scenario +def test_install_bad_pypi_package(): + pass + diff --git a/Python/tests/samples/sample_linear_regression_test.py b/Python/tests/samples/sample_linear_regression_test.py new file mode 100644 index 0000000..25b55b2 --- /dev/null +++ b/Python/tests/samples/sample_linear_regression_test.py @@ -0,0 +1,24 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+ +import sqlmlutils + + +def linear_regression(input_df, x_col, y_col): + from sklearn import linear_model + + X = input_df[[x_col]] + y = input_df[y_col] + + lr = linear_model.LinearRegression() + lr.fit(X, y) + + return lr + + +sqlpy = sqlmlutils.SQLPythonExecutor(sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB")) +sql_query = "select top 1000 CRSDepTime, CRSArrTime from airline5000" +regression_model = sqlpy.execute_function_in_sql(linear_regression, input_data_query=sql_query, + x_col="CRSDepTime", y_col="CRSArrTime") +print(regression_model) +print(regression_model.coef_) diff --git a/Python/tests/samples/sample_scatter_plot_test.py b/Python/tests/samples/sample_scatter_plot_test.py new file mode 100644 index 0000000..1f3ce34 --- /dev/null +++ b/Python/tests/samples/sample_scatter_plot_test.py @@ -0,0 +1,35 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import sqlmlutils +from PIL import Image + + +def scatter_plot(input_df, x_col, y_col): + import matplotlib.pyplot as plt + import io + + title = x_col + " vs. 
" + y_col + + plt.scatter(input_df[x_col], input_df[y_col]) + plt.xlabel(x_col) + plt.ylabel(y_col) + plt.title(title) + + # Save scatter plot image as a png + buf = io.BytesIO() + plt.savefig(buf, format="png") + buf.seek(0) + + # Returns the bytes of the png to the client + return buf + + +sqlpy = sqlmlutils.SQLPythonExecutor(sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB")) + +sql_query = "select top 100 * from airline5000" +plot_data = sqlpy.execute_function_in_sql(func=scatter_plot, input_data_query=sql_query, + x_col="ArrDelay", y_col="CRSDepTime") +im = Image.open(plot_data) +im.show() +#im.save("scatter_test.png") diff --git a/Python/tests/samples/sample_simple_function_test.py b/Python/tests/samples/sample_simple_function_test.py new file mode 100644 index 0000000..5c3ea26 --- /dev/null +++ b/Python/tests/samples/sample_simple_function_test.py @@ -0,0 +1,14 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import sqlmlutils + + +def foo(): + return "bar" + + +sqlpython = sqlmlutils.SQLPythonExecutor(sqlmlutils.ConnectionInfo(server="localhost", database="master")) +result = sqlpython.execute_function_in_sql(foo) +assert result == "bar" + diff --git a/Python/tests/samples/sample_stored_procedure.py b/Python/tests/samples/sample_stored_procedure.py new file mode 100644 index 0000000..a7a67b7 --- /dev/null +++ b/Python/tests/samples/sample_stored_procedure.py @@ -0,0 +1,47 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
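The scatter-plot sample above returns the rendered PNG through an in-memory `io.BytesIO`. The essential buffer round-trip (write, rewind with `seek(0)`, hand the buffer back to the caller) can be shown without matplotlib; the PNG signature bytes below stand in for a real image:

```python
import io


def render_png():
    buf = io.BytesIO()
    # Stand-in bytes; matplotlib's savefig(buf, format="png") would write here
    buf.write(b"\x89PNG\r\n\x1a\n")
    buf.seek(0)  # rewind so the caller reads from the start, as in the sample
    return buf


buf = render_png()
data = buf.read()
assert data.startswith(b"\x89PNG")  # readable because the buffer was rewound
```

Without the `seek(0)`, the read position would sit at the end of the written data and the caller (here, `Image.open`) would see an empty stream.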
+ +import sqlmlutils +import pytest + +def principal_components(input_table: str, output_table: str): + import sqlalchemy + from urllib import parse + import pandas as pd + from sklearn.decomposition import PCA + + # Internal ODBC connection string used by process executing inside SQL Server + connection_string = "Driver=SQL Server;Server=localhost;Database=AirlineTestDB;Trusted_Connection=Yes;" + engine = sqlalchemy.create_engine("mssql+pyodbc:///?odbc_connect={}".format(parse.quote_plus(connection_string))) + + input_df = pd.read_sql("select top 200 ArrDelay,CRSDepTime,DayOfWeek from {}".format(input_table), engine).dropna() + + + pca = PCA(n_components=2) + components = pca.fit_transform(input_df) + + output_df = pd.DataFrame(components) + output_df.to_sql(output_table, engine, if_exists="replace") + + +connection = sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB") + +input_table = "airline5000" +output_table = "AirlineDemoPrincipalComponents" + +sp_name = "SavePrincipalComponents" + +sqlpy = sqlmlutils.SQLPythonExecutor(connection) + +if sqlpy.check_sproc(sp_name): + sqlpy.drop_sproc(sp_name) + +sqlpy.create_sproc_from_function(sp_name, principal_components) + +# You can check the stored procedure exists in the db with this: +assert sqlpy.check_sproc(sp_name) + +sqlpy.execute_sproc(sp_name, input_table=input_table, output_table=output_table) + +sqlpy.drop_sproc(sp_name) +assert not sqlpy.check_sproc(sp_name) \ No newline at end of file diff --git a/Python/tests/scripts/test_packages/testpackageA-0.0.1.zip b/Python/tests/scripts/test_packages/testpackageA-0.0.1.zip new file mode 100644 index 0000000..f24d813 Binary files /dev/null and b/Python/tests/scripts/test_packages/testpackageA-0.0.1.zip differ diff --git a/Python/tests/scripts/test_packages/testpackageA-0.0.2.zip b/Python/tests/scripts/test_packages/testpackageA-0.0.2.zip new file mode 100644 index 0000000..df2c7e4 Binary files /dev/null and 
b/Python/tests/scripts/test_packages/testpackageA-0.0.2.zip differ diff --git a/Python/tests/scripts/test_packages/testpackageA/MANIFEST b/Python/tests/scripts/test_packages/testpackageA/MANIFEST new file mode 100644 index 0000000..f130d7b --- /dev/null +++ b/Python/tests/scripts/test_packages/testpackageA/MANIFEST @@ -0,0 +1,4 @@ +# file GENERATED by distutils, do NOT edit +setup.py +testpackageA\ClassA.py +testpackageA\__init__.py diff --git a/Python/tests/scripts/test_packages/testpackageA/dist/testpackageA-0.0.1.zip b/Python/tests/scripts/test_packages/testpackageA/dist/testpackageA-0.0.1.zip new file mode 100644 index 0000000..f24d813 Binary files /dev/null and b/Python/tests/scripts/test_packages/testpackageA/dist/testpackageA-0.0.1.zip differ diff --git a/Python/tests/scripts/test_packages/testpackageA/setup.py b/Python/tests/scripts/test_packages/testpackageA/setup.py new file mode 100644 index 0000000..a4e9ff9 --- /dev/null +++ b/Python/tests/scripts/test_packages/testpackageA/setup.py @@ -0,0 +1,9 @@ +from distutils.core import setup + +setup( + name='testpackageA' , + packages=['testpackageA'], + version='0.0.1', + description='Test package for python package management.', + author='Microsoft' +) diff --git a/Python/tests/scripts/test_packages/testpackageA/testpackageA/ClassA.py b/Python/tests/scripts/test_packages/testpackageA/testpackageA/ClassA.py new file mode 100644 index 0000000..c50b501 --- /dev/null +++ b/Python/tests/scripts/test_packages/testpackageA/testpackageA/ClassA.py @@ -0,0 +1,8 @@ +class ClassA: + + def __init__(self, val): + self._val = val + + @property + def val(self): + return self._val \ No newline at end of file diff --git a/Python/tests/scripts/test_packages/testpackageA/testpackageA/__init__.py b/Python/tests/scripts/test_packages/testpackageA/testpackageA/__init__.py new file mode 100644 index 0000000..ebd7fef --- /dev/null +++ b/Python/tests/scripts/test_packages/testpackageA/testpackageA/__init__.py @@ -0,0 +1 @@ +from 
.ClassA import ClassA \ No newline at end of file diff --git a/Python/tests/scripts/test_script.py b/Python/tests/scripts/test_script.py new file mode 100644 index 0000000..e982e2b --- /dev/null +++ b/Python/tests/scripts/test_script.py @@ -0,0 +1,9 @@ +def foo(t1, t2, t3): + print(t1 + t2) + print(t3) + return t3 + + +res = foo("Hello","World",InputDataSet) + +print("Testing output!") diff --git a/Python/tests/scripts/test_script_no_out_params.py b/Python/tests/scripts/test_script_no_out_params.py new file mode 100644 index 0000000..572f862 --- /dev/null +++ b/Python/tests/scripts/test_script_no_out_params.py @@ -0,0 +1,9 @@ +def foo(t1, t2, t3): + print(t1 + t2) + print(t3) + return t3 + + +res = foo(t1,t2,t3) + +print("Testing output!") diff --git a/Python/tests/scripts/test_script_no_params.py b/Python/tests/scripts/test_script_no_params.py new file mode 100644 index 0000000..9e7567f --- /dev/null +++ b/Python/tests/scripts/test_script_no_params.py @@ -0,0 +1,9 @@ +def foo(t1, t2, t3): + print(t1 + t2) + print(t3) + return t3 + + +res = foo("No ", "Inputs", "Required") + +print("Testing output!") diff --git a/Python/tests/scripts/test_script_out_param.py b/Python/tests/scripts/test_script_out_param.py new file mode 100644 index 0000000..5d679ff --- /dev/null +++ b/Python/tests/scripts/test_script_out_param.py @@ -0,0 +1,7 @@ +def foo(t1, t2, t3): + return str(t1)+str(t2) + + +res = foo(t1,t2,t3) + +print("Testing output!") diff --git a/Python/tests/scripts/test_script_sproc_out_df.py b/Python/tests/scripts/test_script_sproc_out_df.py new file mode 100644 index 0000000..c5034bd --- /dev/null +++ b/Python/tests/scripts/test_script_sproc_out_df.py @@ -0,0 +1,10 @@ +def foo(t1, t2, t3): + print(t1) + print(t2) + print(t3) + return t3 + + +OutputDataSet = foo(t1,t2,t3) + +print("Testing output!") diff --git a/Python/tests/stored_procedure_test.py b/Python/tests/stored_procedure_test.py new file mode 100644 index 0000000..6362b8e --- /dev/null +++ 
b/Python/tests/stored_procedure_test.py @@ -0,0 +1,461 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import pytest +import sqlmlutils +from contextlib import redirect_stdout +from subprocess import Popen, PIPE, STDOUT +from pandas import DataFrame +import io +import os + +current_dir = os.path.dirname(__file__) +script_dir = os.path.join(current_dir, "scripts") +conn = sqlmlutils.ConnectionInfo(database="AirlineTestDB") +sqlpy = sqlmlutils.SQLPythonExecutor(conn) + + +################### +# No output tests # +################### + +def test_no_output(): + def my_func(): + print("blah blah blah") + + name = "test_no_output" + sqlpy.drop_sproc(name) + + sqlpy.create_sproc_from_function(name, my_func) + assert sqlpy.check_sproc(name) + + x = sqlpy.execute_sproc(name) + assert type(x) == DataFrame + assert x.empty + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +def test_no_output_mixed_args(): + def mixed(val1: int, val2: str, val3: float, val4: bool): + print(val1, val2, val3, val4) + + name = "test_no_output_mixed_args" + sqlpy.drop_sproc(name) + + sqlpy.create_sproc_from_function(name, mixed) + buf = io.StringIO() + with redirect_stdout(buf): + sqlpy.execute_sproc(name, val1=5, val2="blah", val3=15.5, val4=True) + assert "5 blah 15.5 True" in buf.getvalue() + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +def test_no_output_mixed_args_in_df(): + def mixed(val1: int, val2: str, val3: float, val4: bool, val5: DataFrame): + print(val1, val2, val3, val4) + print(val5) + + name = "test_no_output_mixed_args_in_df" + sqlpy.drop_sproc(name) + + sqlpy.create_sproc_from_function(name, mixed) + buf = io.StringIO() + with redirect_stdout(buf): + sqlpy.execute_sproc(name, val1=5, val2="blah", val3=15.5, val4=False, val5="SELECT TOP 2 * FROM airline5000") + assert "5 blah 15.5 False" in buf.getvalue() + assert "ArrTime" in buf.getvalue() + assert "CRSDepTime" in buf.getvalue() + assert 
"DepTime" in buf.getvalue() + assert "CancellationCode" in buf.getvalue() + assert "DayOfWeek" in buf.getvalue() + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +def test_no_output_mixed_args_in_df_in_params(): + def mixed(val1, val2, val3, val4, val5): + print(val1, val2, val3, val5) + print(val4) + + in_params = {"val1": int, "val2": str, "val3": float, "val4": DataFrame, "val5": bool} + name = "test_no_output_mixed_args_in_df_in_params" + sqlpy.drop_sproc(name) + + sqlpy.create_sproc_from_function(name=name, func=mixed, input_params=in_params) + buf = io.StringIO() + with redirect_stdout(buf): + sqlpy.execute_sproc(name, val1=5, val2="blah", val3=15.5, val4="SELECT TOP 2 * FROM airline5000", val5=False) + assert "5 blah 15.5 False" in buf.getvalue() + assert "ArrTime" in buf.getvalue() + assert "CRSDepTime" in buf.getvalue() + assert "DepTime" in buf.getvalue() + assert "CancellationCode" in buf.getvalue() + assert "DayOfWeek" in buf.getvalue() + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +################ +# Test outputs # +################ + +def test_out_df_no_params(): + def no_params(): + df = DataFrame() + df["col1"] = [1, 2, 3, 4, 5] + return df + + name = "test_out_df_no_params" + sqlpy.drop_sproc(name) + + sqlpy.create_sproc_from_function(name, no_params) + assert sqlpy.check_sproc(name) + + df = sqlpy.execute_sproc(name) + assert list(df.iloc[:,0] == [1, 2, 3, 4, 5]) + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +def test_out_df_with_args(): + def my_func_with_args(arg1: str, arg2: str): + return DataFrame({"arg1": [arg1], "arg2": [arg2]}) + + name = "test_out_df_with_args" + sqlpy.drop_sproc(name) + + sqlpy.create_sproc_from_function(name, my_func_with_args) + assert sqlpy.check_sproc(name) + + vals = [("arg1val", "arg2val"), ("asd", "Asd"), ("Qwe", "Qwe"), ("zxc", "Asd")] + + for values in vals: + arg1 = values[0] + arg2 = values[1] + res = sqlpy.execute_sproc(name, arg1=arg1, 
arg2=arg2) + assert res[0][0] == arg1 + assert res[1][0] == arg2 + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +def test_out_df_in_df(): + def in_data(in_df: DataFrame): + return in_df + + name = "test_out_df_in_df" + sqlpy.drop_sproc(name) + + sqlpy.create_sproc_from_function(name, in_data) + assert sqlpy.check_sproc(name) + + res = sqlpy.execute_sproc(name, in_df="SELECT TOP 10 * FROM airline5000") + + assert type(res) == DataFrame + assert res.shape == (10, 30) + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +def test_out_df_mixed_args_in_df(): + def mixed(val1: int, val2: str, val3: float, val4: DataFrame, val5: bool): + print(val1, val2, val3, val5) + if val5 and val1 == 5 and val2 == "blah" and val3 == 15.5: + return val4 + else: + return None + + name = "test_out_df_mixed_args_in_df" + sqlpy.drop_sproc(name) + + sqlpy.create_sproc_from_function(name, mixed) + + res = sqlpy.execute_sproc(name, val1=5, val2="blah", val3=15.5, + val4="SELECT TOP 10 * FROM airline5000", val5=True) + + assert type(res) == DataFrame + assert res.shape == (10, 30) + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +def test_out_df_mixed_in_params_in_df(): + def mixed(val1, val2, val3, val4, val5): + print(val1, val2, val3, val5) + if val5 and val1 == 5 and val2 == "blah" and val3 == 15.5: + return val4 + else: + return None + + name = "test_out_df_mixed_in_params_in_df" + sqlpy.drop_sproc(name) + + input_params = {"val1": int, "val2": str, "val3": float, "val4": DataFrame, "val5": bool} + + sqlpy.create_sproc_from_function(name, mixed, input_params=input_params) + assert sqlpy.check_sproc(name) + + res = sqlpy.execute_sproc(name, val1=5, val2="blah", val3=15.5, + val4="SELECT TOP 10 * FROM airline5000", val5=True) + + assert type(res) == DataFrame + assert res.shape == (10, 30) + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +def test_out_of_order_args(): + def mixed(val1, val2, val3, val4, val5): + 
return DataFrame({"val1": [val1], "val2": [val2], "val3": [val3], "val5": [val5]}) + + in_params = {"val2": str, "val3": float, "val5": bool, "val4": DataFrame, "val1": int} + + name = "test_out_of_order_args" + sqlpy.drop_sproc(name) + + sqlpy.create_sproc_from_function(name=name, func=mixed, input_params=in_params) + assert sqlpy.check_sproc(name) + + v1 = 5 + v2 = "blah" + v3 = 15.5 + v4 = "SELECT TOP 10 * FROM airline5000" + res = sqlpy.execute_sproc(name, val5=False, val3=v3, val4=v4, val1=v1, val2=v2) + + assert res[0][0] == v1 + assert res[1][0] == v2 + assert res[2][0] == v3 + assert not res[3][0] + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +# TODO: Output Params execution not currently supported +def test_in_param_out_param(): + def in_out(t1, t2, t3): + print(t2) + print(t3) + res = "Hello " + t1 + return {'out_df': t3, 'res': res} + + name = "test_in_param_out_param" + sqlpy.drop_sproc(name) + + input_params = {"t1": str, "t2": int, "t3": DataFrame} + output_params = {"res": str, "out_df": DataFrame} + + sqlpy.create_sproc_from_function(name, in_out, input_params=input_params, output_params=output_params) + assert sqlpy.check_sproc(name) + + # Out params don't currently work so we use sqlcmd to test the output param sproc + sql_str = "DECLARE @res nvarchar(max) EXEC test_in_param_out_param @t2 = 213, @t1 = N'Hello', " \ + "@t3 = N'select top 10 * from airline5000', @res = @res OUTPUT SELECT @res as N'@res'" + p = Popen(["sqlcmd", "-S", conn.server, "-E", "-d", conn.database, "-Q", sql_str], + shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT) + output = p.stdout.read() + assert "Hello Hello" in output.decode() + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +def test_in_df_out_df_dict(): + def func(in_df: DataFrame): + return {"out_df": in_df} + + name = "test_in_df_out_df_dict" + sqlpy.drop_sproc(name) + + output_params = {"out_df": DataFrame} + + sqlpy.create_sproc_from_function(name, func, 
output_params=output_params) + assert sqlpy.check_sproc(name) + + res = sqlpy.execute_sproc(name, in_df="SELECT TOP 10 * FROM airline5000") + + assert type(res) == DataFrame + assert res.shape == (10, 30) + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +################ +# Script Tests # +################ + +def test_script_no_params(): + script = os.path.join(script_dir, "test_script_no_params.py") + + name = "test_script_no_params" + sqlpy.drop_sproc(name) + + sqlpy.create_sproc_from_script(name, script) + assert sqlpy.check_sproc(name) + + buf = io.StringIO() + with redirect_stdout(buf): + sqlpy.execute_sproc(name) + assert "No Inputs" in buf.getvalue() + assert "Required" in buf.getvalue() + assert "Testing output!" in buf.getvalue() + assert "HelloWorld" not in buf.getvalue() + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +def test_script_no_out_params(): + script = os.path.join(script_dir, "test_script_no_out_params.py") + + name = "test_script_no_out_params" + sqlpy.drop_sproc(name) + + input_params = {"t1": str, "t2": str, "t3": int} + + sqlpy.create_sproc_from_script(name, script, input_params) + assert sqlpy.check_sproc(name) + + buf = io.StringIO() + with redirect_stdout(buf): + sqlpy.execute_sproc(name, t1="Hello", t2="World", t3=312) + assert "HelloWorld" in buf.getvalue() + assert "312" in buf.getvalue() + assert "Testing output!" 
in buf.getvalue() + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +def test_script_out_df(): + script = os.path.join(script_dir, "test_script_sproc_out_df.py") + + name = "test_script_out_df" + sqlpy.drop_sproc(name) + + input_params = {"t1": str, "t2": int, "t3": DataFrame} + + sqlpy.create_sproc_from_script(name, script, input_params) + assert sqlpy.check_sproc(name) + + res = sqlpy.execute_sproc(name, t1="Hello", t2=2313, t3="SELECT TOP 10 * FROM airline5000") + + assert type(res) == DataFrame + assert res.shape == (10, 30) + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +#TODO: Output Params execution not currently supported +def test_script_out_param(): + script = os.path.join(script_dir, "test_script_out_param.py") + + name = "test_script_out_param" + sqlpy.drop_sproc(name) + + input_params = {"t1": str, "t2": int, "t3": DataFrame} + output_params = {"res": str} + + sqlpy.create_sproc_from_script(name, script, input_params, output_params) + assert sqlpy.check_sproc(name) + + # Out params don't currently work so we use sqlcmd to test the output param sproc + sql_str = "DECLARE @res nvarchar(max) EXEC test_script_out_param @t2 = 123, @t1 = N'Hello', " \ + "@t3 = N'select top 10 * from airline5000', @res = @res OUTPUT SELECT @res as N'@res'" + p = Popen(["sqlcmd", "-S", conn.server, "-E", "-d", conn.database, "-Q", sql_str], + shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT) + output = p.stdout.read() + assert "Hello123" in output.decode() + + sqlpy.drop_sproc(name) + assert not sqlpy.check_sproc(name) + + +################## +# Negative Tests # +################## + +def test_execute_bad_param_types(): + def bad_func(input1: bin): + pass + + with pytest.raises(ValueError): + sqlpy.create_sproc_from_function("BadParam", bad_func) + + def func(input1: bool): + pass + name = "BadInput" + sqlpy.drop_sproc(name) + sqlpy.create_sproc_from_function(name, func) + assert sqlpy.check_sproc(name) + + with 
pytest.raises(RuntimeError): + sqlpy.execute_sproc(name, input1="Hello!") + + +def test_create_bad_name(): + def foo(): + return 1 + with pytest.raises(RuntimeError): + sqlpy.create_sproc_from_function("'''asd''asd''asd", foo) + + +def test_no_output_bad_num_args(): + def mixed(val1: str, val2, val3, val4): + print(val1, val2, val3) + print(val4) + + name = "test_no_output_bad_num_args" + sqlpy.drop_sproc(name) + + with pytest.raises(ValueError): + sqlpy.create_sproc_from_function(name=name, func=mixed) + + def func(val1, val2, val3, val4): + print(val1, val2, val3) + print(val4) + + input_params = {"val1": int, "val4": str, "val5": int, "BADVAL": str} + sqlpy.drop_sproc(name) + + with pytest.raises(ValueError): + sqlpy.create_sproc_from_function(name=name, func=func, input_params=input_params) + + input_params = {"val1": int, "val2": int, "val3": str} + sqlpy.drop_sproc(name) + + with pytest.raises(ValueError): + sqlpy.create_sproc_from_function(name=name, func=func, input_params=input_params) + + +def test_annotation_vs_input_param(): + def foo(val1: str, val2: int, val3: int): + print(val1) + print(val2) + return val3 + + name = "test_input_param_override_error" + input_params = {"val1": str, "val2": int, "val3": DataFrame} + + sqlpy.drop_sproc(name) + with pytest.raises(ValueError): + sqlpy.create_sproc_from_function(name=name, func=foo, input_params=input_params) + + +def test_bad_script_path(): + with pytest.raises(FileNotFoundError): + sqlpy.create_sproc_from_script(name="badScript", path_to_script="NonexistentScriptPath") + diff --git a/R/DESCRIPTION b/R/DESCRIPTION new file mode 100644 index 0000000..1bf7638 --- /dev/null +++ b/R/DESCRIPTION @@ -0,0 +1,19 @@ +Package: sqlmlutils +Type: Package +Title: Wraps R code into executable SQL Server stored procedures +Version: 0.5.0 +Author: Microsoft Corporation +Maintainer: Microsoft Corporation +Depends: + R (>= 3.2.2) +Imports: + RODBC, RODBCext, tools, methods, utils +Description: Provides a set of functions 
allowing the user + to wrap their R script into a TSQL stored procedure, register + that stored procedure with a database, and test it from an R + development environment. +License: MIT + file LICENSE +Copyright: Copyright 2016 Microsoft Corporation +RoxygenNote: 6.0.1 +Suggests: testthat, + roxygen2 diff --git a/R/LICENSE b/R/LICENSE new file mode 100644 index 0000000..e39502a --- /dev/null +++ b/R/LICENSE @@ -0,0 +1,25 @@ +------------------------------------------- START OF LICENSE ----------------------------------------- +sqlmlutils + +MIT License + +Copyright (c) Microsoft Corporation. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE +----------------------------------------------- END OF LICENSE ------------------------------------------ diff --git a/R/NAMESPACE b/R/NAMESPACE new file mode 100644 index 0000000..971758e --- /dev/null +++ b/R/NAMESPACE @@ -0,0 +1,17 @@ +# Generated by roxygen2: do not edit by hand + +export(checkSproc) +export(connectionInfo) +export(createSprocFromFunction) +export(createSprocFromScript) +export(dropSproc) +export(executeFunctionInSQL) +export(executeSQLQuery) +export(executeScriptInSQL) +export(executeSproc) +export(sql_install.packages) +export(sql_installed.packages) +export(sql_remove.packages) +import(RODBC) +importFrom(RODBCext,sqlExecute) +importFrom(utils,tail) diff --git a/R/R/executeInSQL.R b/R/R/executeInSQL.R new file mode 100644 index 0000000..7c26735 --- /dev/null +++ b/R/R/executeInSQL.R @@ -0,0 +1,293 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + + + +#' +#'Create a connection string for a SQL Server database +#' +#'@param driver The driver to use for the connection - defaults to SQL Server +#'@param server The server to connect to - defaults to localhost +#'@param database The database to connect to - defaults to master +#'@param uid The user id for the connection. If uid is NULL, default to Trusted Connection +#'@param pwd The password for the connection.
If uid is not NULL, pwd is required +#' +#'@return A fully formed connection string +#' +#' +#'@examples +#'\dontrun{ +#' +#' connectionInfo() +#' [1] "Driver={SQL Server};Server=localhost;Database=master;Trusted_Connection=Yes;" +#' +#' connectionInfo(server="ServerName", database="AirlineTestDB", uid="username", pwd="pass") +#' [1] "Driver={SQL Server};Server=ServerName;Database=AirlineTestDB;uid=username;pwd=pass;" +#' } +#' +#' +#'@export +connectionInfo <- function(driver = "SQL Server", server = "localhost", database = "master", + uid = NULL, pwd = NULL) { + authorization <- "Trusted_Connection=Yes" + + if (!is.null(uid)) { + if (is.null(pwd)) { + stop("Need a password if using uid") + } else { + authorization = sprintf("uid=%s;pwd=%s",uid,pwd) + } + } + + connection <- sprintf("Driver={%s};Server=%s;Database=%s;%s;", driver, server, database, authorization) + connection +} + + + +#' +#'Execute a function in SQL +#' +#'@param connectionString character string. The connectionString to the database +#'@param func closure. The function to execute +#'@param ... A named list of arguments to pass into the function +#'@param inputDataQuery character string. A string to query the database. +#' The result of the query will be put into a data frame into the first argument in the function +#' +#'@return The returned value from the function +#' +#'@seealso +#'\code{\link{executeScriptInSQL}} to execute a script file instead of a function in SQL +#' +#' +#'@examples +#'\dontrun{ +#' connection <- connectionInfo(database = "AirlineTestDB") +#' +#' foo <- function(in_df, arg) { +#' list(data = in_df, value = arg) +#' } +#' executeFunctionInSQL(connection, foo, arg = 12345, +#' inputDataQuery = "SELECT top 1 * from airline5000") +#'} +#' +#' +#'@export +executeFunctionInSQL <- function(connectionString, func, ..., inputDataQuery = "") +{ + inputDataName <- "InputDataSet" + listArgs <- list(...) 
+ + if (inputDataQuery != "") { + funcArgs <- methods::formalArgs(func) + if (length(funcArgs) < 1) { + stop("To use the inputDataQuery variable, the function must have at least one input argument") + } else { + inputDataName <- funcArgs[[1]] + } + } + binArgs <- serialize(listArgs, NULL) + + spees <- speesBuilderFromFunction(func = func, inputDataQuery = inputDataQuery, inputDataName = inputDataName, binArgs) + resVal <- execute(connectionString = connectionString, script = spees) + return(resVal[[1]]) +} + +#' +#'Execute a script in SQL +#' +#'@param connectionString character string. The connectionString to the database +#'@param script character string. The path to the script to execute in SQL +#'@param inputDataQuery character string. A string to query the database. +#' The result of the query will be put into a data frame into the variable "InputDataSet" in the environment +#' +#'@return The returned value from the last line of the script +#' +#'@seealso +#'\code{\link{executeFunctionInSQL}} to execute a user function instead of a script in SQL +#' +#'@export +executeScriptInSQL <- function(connectionString, script, inputDataQuery = "") +{ + + if (file.exists(script)){ + print(paste0("Script path exists, using file ", script)) + } else { + stop("Script path doesn't exist") + } + + text <- paste(readLines(script), collapse="\n") + + func <- function(InputDataSet, script) { + eval(parse(text = script)) + } + + executeFunctionInSQL(connectionString = connectionString, func = func, script = text, inputDataQuery = inputDataQuery) +} + + +#' +#'Execute a SQL query +#' +#'@param connectionString character string. The connectionString to the database +#'@param sqlQuery character string.
The query to execute +#' +#'@return The data frame returned by the query to the database +#' +#' +#'@examples +#'\dontrun{ +#' connection <- connectionInfo(database="AirlineTestDB") +#' executeSQLQuery(connection, sqlQuery="SELECT top 1 * from airline5000") +#'} +#' +#' +#'@export +executeSQLQuery <- function(connectionString, sqlQuery) +{ + #We use the serialize method here instead of OutputDataSet <- InputDataSet to preserve column names + + script <- " + serializedResult <- as.character(serialize(list(result = InputDataSet), NULL)) + OutputDataSet <- data.frame(returnVal=serializedResult)" + spees <- speesBuilder(script = script, inputDataQuery = sqlQuery, TRUE) + execute(connectionString, spees)$result +} + +# +#Execute and process a script +# +#@param connectionString character string. The connectionString to the database +#@param script character string. The script to execute +# +# +execute <- function(connectionString, script) +{ + tryCatch({ + dbhandle <- odbcDriverConnect(connectionString) + res <- sqlQuery(dbhandle, script) + if (typeof(res) == "character") { + stop(res[1]) + } + binVal <- res$returnVal + }, error = function(e) { + stop(paste0("Error in SQL Execution: ", e, "\n")) + }, finally ={ + odbcCloseAll() + }) + binVal <- res$returnVal + if (!is.null(binVal)) { + resVal <- unserialize(unlist(lapply(lapply(as.character(binVal),as.hexmode), as.raw))) + len <- length(resVal) + + # Each piece of the returned value is a different part of the output + # 1. The result of the function + # 2. The output of the function (e.g. from print()) + # 3. The warnings of the function + # 4. 
The errors of the function + # We raise warnings and errors, print any output, and return the actual function results to the user + + if (len > 1) { + output <- resVal[[2]] + for(o in output) { + cat(paste0(o,"\n")) + } + } + if (len > 2) { + warnings <- resVal[[3]] + for(w in warnings) { + warning(w) + } + } + if (len > 3) { + errors <- resVal[[4]] + for(e in errors) { + stop(paste0("Error in script: ", e)) + } + } + return(resVal) + } else { + return(res) + } +} + +# +#Build an R sp_execute_external_script +# +#@param script The script to execute +#@param inputDataQuery The query on the database +#@param withResults Whether to have a result set, outside of the OutputDataSet +# +speesBuilder <- function(script, inputDataQuery, withResults = FALSE) { + + resultSet <- if (withResults) "with result sets((returnVal varchar(MAX)))" else "" + + sprintf("exec sp_execute_external_script + @language = N'R', + @script = N' + %s + ', + @input_data_1 = N'%s' + %s + ", script, inputDataQuery, resultSet) +} + +# +#Build a spees call from a function +# +#@param func The function to make into a spees +#@param inputDataQuery The input data query to the database +#@param inputDataName The name of the variable to put the data frame from the query into in the script +#@param binArgs The (binary) version of all arguments passed into the function +# +#@return The spees script to execute +#The spees script will return a data frame with the results, serialized +# +speesBuilderFromFunction <- function(func, inputDataQuery, inputDataName, binArgs) { + + funcName <- deparse(substitute(func)) + funcBody <- gsub('"', '\"', paste0(deparse(func), collapse = "\n")) + + speesBody <- sprintf(" + %s <- %s + + + oldWarn <- options(\"warn\")$warn + options(warn=1) + + output <- NULL + result <- NULL + funerror <- NULL + funwarnings <- NULL + try(withCallingHandlers({ + + binArgList <- unlist(lapply(lapply(strsplit(\"%s\",\";\")[[1]], as.hexmode), as.raw)) + argList <- as.list(unserialize(binArgList)) 
+ + if (nrow(InputDataSet)!=0) { + argList <- c(list(%s = InputDataSet), argList) + } + + funwarnings <- capture.output( + output <- capture.output( + result <- do.call(%s, argList) + ), + type=\"message\") + + }, error = function(err) { + funerror <<- err + } + ), silent = TRUE + ) + + options(warn=oldWarn) + + serializedResult <- as.character(serialize(list(result, output, funwarnings, funerror), NULL)) + OutputDataSet <- data.frame(returnVal=serializedResult) + + ", funcName, funcBody, paste0(binArgs,collapse=";"), inputDataName, funcName) + + #Call the spees builder to wrap the function; needs the returnVal resultset + speesBuilder(speesBody, inputDataQuery, withResults = TRUE) +} + diff --git a/R/R/sqlPackage.R b/R/R/sqlPackage.R new file mode 100644 index 0000000..a126b20 --- /dev/null +++ b/R/R/sqlPackage.R @@ -0,0 +1,2047 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + + +# max size in chars of the owner parameter to limit sql injection attacks +# (the owner is used in CREATE EXTERAL LIBRARY AUTHORIZATION) +MAX_OWNER_SIZE_CONST <- 128 + + +#' sql_installed.packages +#' @description Enumerates the currently installed R packages on a SQL Server for the current database +#' +#' @param connectionString ODBC connection string to Microsoft SQL Server database. +#' @param priority character vector or NULL (default). If non-null, used to select packages; "high" is equivalent to c("base", "recommended"). To select all packages without an assigned priority use priority = "NA". +#' @param noCache logical. If TRUE, do not use cached information, nor cache it. +#' @param fields a character vector giving the fields to extract from each package's DESCRIPTION file, or NULL. 
If NULL, the following fields are used: +#' "Package", "LibPath", "Version", "Priority", "Depends", "Imports", "LinkingTo", "Suggests", "Enhances", "License", "License_is_FOSS", "License_restricts_use", "OS_type", "MD5sum", "NeedsCompilation", and "Built". +#' Unavailable fields result in NA values. +#' @param subarch character string or NULL. If non-null and non-empty, used to select packages which are installed for that sub-architecture. +#' @param scope character string which can be "private" or "public". +#' @param owner character string of a user whose private packages shall be listed (available to dbo or db_owner users only) +#' @return matrix with enumerated packages +#' +#'@seealso{ +#'\code{\link{sql_install.packages}} to install packages +#' +#'\code{\link{sql_remove.packages}} to remove packages +#' +#'\code{\link{installed.packages}} for the base version of this function +#' +#'} +#' @export +sql_installed.packages <- function(connectionString, + priority = NULL, noCache = FALSE, fields = "Package", + subarch = NULL, scope = "private", owner = '') +{ + enumResult <- NULL + + checkOwner(owner) + checkConnectionString(connectionString) + checkVersion(connectionString) + scope <- normalizeScope(scope) + + enumResult <- list(packages = NULL, warnings = NULL, errors = NULL) + enumResult <- sqlEnumPackages( + connectionString = connectionString, + owner = owner, scope = scope, + priority = priority, fields = fields, subarch = subarch) + + if (!is.null(enumResult$errors)){ + stop(enumResult$errors, call. = FALSE) + } + + if (!is.null(enumResult$warnings)){ + warning(enumResult$warnings, immediate. = TRUE) + } + + enumResult <- enumResult$packages + + return(enumResult) +} + + +#' sql_install.packages +#' @description Installs R packages on a SQL Server database. Packages are downloaded on the client and then copied and installed to SQL Server into "public" and "private" folders.
Packages in the "public" folders can be loaded by all database users running R scripts in SQL. Packages in the "private" folder can be loaded only by a single user. 'dbo' users always install into the "public" folder. Users who are members of the 'db_owner' role can install to both "public" and "private" folders. All other users can only install packages to their "private" folder. +#' +#' @param connectionString ODBC connection string to Microsoft SQL Server database. +#' @param pkgs character vector of the names of packages whose current versions should be downloaded from the repositories. If repos = NULL, a character vector of file paths of .zip files containing binary builds of packages. (http:// and file:// URLs are also accepted and the files will be downloaded and installed from local copies). +#' @param skipMissing logical. If TRUE, skips missing dependent packages for which otherwise an error is generated. +#' @param repos character vector, the base URL(s) of the repositories to use. Can be NULL to install from local files or directories. +#' @param verbose logical. If TRUE, more detailed information is given during installation of packages. +#' @param scope character string. Should be either "public" or "private". "public" installs the packages in a per-database public location on SQL Server which in turn can be used (referred) by multiple different users. "private" installs the packages in a per-database, per-user private location on SQL Server which is only accessible to the single user. +#' @param owner character string. Should be either empty '' or a valid SQL database user account name. Only 'dbo' or users in 'db_owner' role for a database can specify this value to install packages on behalf of other users. A user who is a member of the 'db_owner' group can set owner='dbo' to install on the "public" folder.
+#' @return invisible(NULL) +#' +#'@seealso{ +#'\code{\link{sql_remove.packages}} to remove packages +#' +#'\code{\link{sql_installed.packages}} to enumerate the installed packages +#' +#'\code{\link{install.packages}} for the base version of this function +#'} +#' +#' @export +sql_install.packages <- function(connectionString, + pkgs, + skipMissing = FALSE, repos = getOption("repos"), + verbose = getOption("verbose"), scope = "private", owner = '') +{ + checkOwner(owner) + checkConnectionString(connectionString) + serverIsWindows <- checkVersion(connectionString) + sqlInstallPackagesExtLib(connectionString, + pkgs = pkgs, + skipMissing = skipMissing, repos = repos, + verbose = verbose, scope = scope, owner = owner, + serverIsWindows = serverIsWindows) + + return(invisible(NULL)) +} + +#' sql_remove.packages +#' +#' @param connectionString ODBC connection string to SQL Server database. +#' @param pkgs character vector of names of the packages to be removed. +#' @param dependencies logical. If TRUE, does dependency resolution of the packages being removed and removes the dependent packages also if the dependent packages aren't referenced by other packages outside the dependency closure. +#' @param checkReferences logical. If TRUE, verifies there are no references to the dependent packages by other packages outside the dependency closure. Use FALSE to force removal of packages even when other packages depend on it. +#' @param verbose logical. If TRUE, more detailed information is given during removal of packages. +#' @param scope character string. Should be either "public" or "private". "public" removes the packages from a per-database public location on SQL Server which in turn could have been used (referred) by multiple different users. "private" removes the packages from a per-database, per-user private location on SQL Server which is only accessible to the single user. +#' @param owner character string. 
Should be either empty '' or a valid SQL database user account name. Only 'dbo' or users in 'db_owner' role for a database can specify this value to remove packages on behalf of other users. A user who is member of the 'db_owner' group can set owner='dbo' to remove packages from the "public" folder. +#' @return invisible(NULL) +#' +#'@seealso{ +#'\code{\link{sql_install.packages}} to install packages +#' +#'\code{\link{sql_installed.packages}} to enumerate the installed packages +#' +#'\code{\link{remove.packages}} for the base version of this function +#' +#'} +#' +#' @export +sql_remove.packages <- function(connectionString, pkgs, dependencies = TRUE, checkReferences = TRUE, + verbose = getOption("verbose"), scope = "private", owner = '') +{ + checkOwner(owner) + checkConnectionString(connectionString) + checkVersion(connectionString) + + if (length(pkgs) == 0){ + stop("no packages provided to remove") + } + + scope <- normalizeScope(scope) + scopeint <- parseScope(scope) + + if (scope == "PUBLIC" && is.character(owner) && nchar(owner) >0) + { + stop(paste0("Invalid use of scope PUBLIC. Use scope 'PRIVATE' to remove packages for owner '", owner ,"'\n"), call. 
= FALSE) + } + + pkgsToDrop <- NULL # packages to drop from table and not found on file system + if ((length(pkgs) > 0) && ((dependencies == TRUE) || (checkReferences == TRUE))) + { + # + # get installed packages + # + if (verbose) + { + write(sprintf("%s Enumerating installed packages on SQL server...", pkgTime()), stdout()) + } + + installedPackages <- sql_installed.packages(connectionString = connectionString, fields = NULL, scope = scope, owner = owner) + installedPackages <- data.frame(installedPackages, row.names = NULL, stringsAsFactors = FALSE) + installedPackages <- installedPackages[installedPackages$Scope == scope,] + + # + # check for missing packages on the library paths + # + missingPackages <- pkgs[!(pkgs %in% installedPackages$Package)] + + if (length(missingPackages) > 0) + { + # we know package is not one the file-system, but it may be in the table (e.g. using create external library) and failing to install + # if a package is only in the table we still want to remove it + tablePackages <- sqlEnumTable(connectionString, missingPackages, owner, scopeint) + if (!is.null(tablePackages)) + { + missingPackages <- tablePackages[tablePackages["Found"] == 0,, drop = FALSE] + } + + if (nrow(missingPackages) > 0) + { + stop(sprintf("Cannot find specified packages (%s) to remove from scope '%s'", paste(missingPackages$Package, collapse = ', '), scope), call. 
= FALSE)
+            }
+
+            pkgsToDrop <- tablePackages$Package
+            pkgs <- pkgs[pkgs %in% installedPackages$Package]
+        }
+
+        #
+        # get the list of dependent packages which are safe to remove
+        #
+        pkgsToUninstall <- getDependentPackagesToUninstall(pkgs, installedPackages = installedPackages,
+                                                           dependencies = dependencies, checkReferences = checkReferences, verbose = verbose)
+
+        if (is.null(pkgsToUninstall))
+        {
+            pkgs <- NULL
+        }
+        else
+        {
+            pkgs <- pkgsToUninstall$Package
+        }
+    }
+
+    if (length(pkgs) > 0 || length(pkgsToDrop) > 0)
+    {
+        if (verbose)
+        {
+            write(sprintf("%s Uninstalling packages on SQL server (%s)...", pkgTime(), paste(c(pkgs, pkgsToDrop), collapse = ', ')), stdout())
+        }
+
+        sqlHelperRemovePackages(connectionString, pkgs, pkgsToDrop, scope, owner)
+    }
+    return(invisible(NULL))
+}
+
+
+#
+# Executes an R function on a remote SQL Server
+# using sp_execute_external_script.
+# This is a variation on executing in SQL, with a few extra params.
+#
+# @param connection odbc connection string or a valid RODBC handle
+#
+# @param FUN function to execute
+# @param ... parameters passed to FUN
+#
+# @param useRemoteFun by default, inserts the function definition as available on the client as text into sp_execute_external_script;
+#                     if TRUE, uses the function as available on the remote server
+#
+# @param asuser calls sp_execute_external_script with EXECUTE AS USER 'asuser'
+#
+# @return data frame returned by FUN
+#
+sqlRemoteExecuteFun <- function(connection, FUN, ..., useRemoteFun = FALSE, asuser = NULL, includeFun = list())
+{
+    if (class(connection) == "character"){
+        if (nchar(connection) < 1){
+            stop(paste0("Invalid connection string: ", connection), call. = FALSE)
+        }
+    } else if (class(connection) != "RODBC"){
+        stop("Invalid connection: has to be a character string or an RODBC handle", call. 
= FALSE)
+    }
+
+    if (is.character(asuser) && length(asuser) == 1){
+        if (nchar(asuser) == 0){
+            asuser <- NULL
+        }
+    } else {
+        asuser <- NULL
+    }
+
+    # input processing and checking
+    if (is.function(FUN)) {
+        funName <- deparse(substitute(FUN))
+    } else {
+        if (!is.character(FUN))
+            stop("you must provide either a function object or the name of a function")
+        funName <- FUN
+        FUN <- match.fun(FUN)
+    }
+
+    #
+    # captures the R code and formats it to be embedded in T-SQL sp_execute_external_script
+    #
+    deparseforSql <- function(funName, fun)
+    {
+        # counts the number of spaces at the beginning of the string
+        countSpacesAtBegin <- function(s) {
+            p <- gregexpr("^ *", s)
+            return(attr(p[[1]], "match.length"))
+        }
+
+        funBody <- deparse(fun)
+
+        # prepend the function definition
+        funBody[1] <- paste(funName, "<-", funBody[1], sep = " ")
+
+        # escape single quotes and get rid of tabs
+        funBody <- sapply(funBody, gsub, pattern = "\t", replacement = " ")
+        funBody <- sapply(funBody, gsub, pattern = "'", replacement = "''")
+
+        # handle the case where the function's R code was indented
+        # more than 2 spaces and get rid of extra spaces.
+        # otherwise the resulting indentation of the R code in T-SQL
+        # would depend on the indentation of the code in the R source
+        if (length(funBody) > 1)
+        {
+            # temporarily discard empty lines so they don't affect space counting
+            no_empty_lines <- funBody[funBody != ""]
+
+            # remove the first line (function declaration line) from no_empty_lines
+            # and if the last line only contains a closing bracket align it with
+            # the function declaration and remove as well
+            if (grepl("^ *} *$", funBody[length(funBody)])) {
+                funBody[length(funBody)] <- "}"
+                no_empty_lines <- no_empty_lines[2:(length(no_empty_lines) - 1)]
+            } else {
+                no_empty_lines <- no_empty_lines[2:length(no_empty_lines)]
+            }
+
+            # find the minimum number of extra spaces
+            extra_spaces <- min(sapply(no_empty_lines, countSpacesAtBegin)) - 2
+
+            # remove extra spaces
+            if (extra_spaces > 0) {
+                for (i in 2:(length(funBody) - 1)) {
+                    funBody[i] <- gsub(paste("^ {", extra_spaces,"}", sep = ""),
+                                       "", funBody[i])
+                }
+            }
+        }
+
+        funText <- paste(funBody, collapse = "\n")
+
+        return (funText)
+    }
+
+
+    # Define a function that will attempt to resolve the ellipsis arguments
+    # passed into sqlRemoteExecuteFun and return those elements in a (named) list.
+    # For those elements that are not resolvable, leave them as promises to be
+    # evaluated on the server. This scheme avoids, for example, the need to have
+    # a particular package loaded locally in order to (locally) resolve symbols/data
+    # that belong to that package; such a package only needs to be loaded on the
+    # server for the promised symbols to be resolvable.
+    tryEvalArgList <- function(...)
+    {
+        # Convert ellipsis arguments into a list of substituted values,
+        # which will result in names, symbols, or language objects and
+        # will avoid the evaluation.
+        argListSubstitute <- as.list(substitute(list(...)))[-1L]
+
+        # Now attempt to evaluate each argument. 
If we fail, then keep the
+        # argument value as a substituted value. These substituted values
+        # essentially act as a promise and will be evaluated on the server.
+        # If they also fail (are not resolvable) on the server, an error will
+        # be returned.
+        envir <- parent.frame(n = 2)
+        sapply(argListSubstitute, function(x, envir)
+        {
+            res <- try(eval(x, envir = envir), silent = TRUE)
+            if (!inherits(res, "try-error")) res else x
+        }, envir = envir, simplify = FALSE)
+    }
+
+    argList <- tryEvalArgList(...)
+    binArgList <- serialize(argList, NULL)
+    binArgListCollapse <- paste0(binArgList, collapse = ";")
+
+    script <- ""
+
+    if (length(includeFun) > 0)
+    {
+        includeFunNames <- names(includeFun)
+        for (i in seq_along(includeFun))
+        {
+            script <- paste0(script, "\n", deparseforSql(includeFunNames[[i]], includeFun[[i]]))
+        }
+    }
+
+    if (!useRemoteFun)
+    {
+        funText <- deparseforSql(funName, FUN)
+        script <- paste0(script, "\n", funText)
+    }
+
+
+    script <- paste0(script,
+        sprintf("
+            result <- NULL
+            funerror <- NULL
+            funwarnings <- NULL
+            output <- capture.output(try(
+                withCallingHandlers({
+                    binArgList <- unlist(lapply(lapply(strsplit(\"%s\",\";\")[[1]], as.hexmode), as.raw))
+                    argList <- as.list(unserialize(binArgList))
+                    result <- do.call(%s, argList)
+                }, error = function(err) {
+                    funerror <<- err
+                }, warning = function(warn) {
+                    funwarnings <<- c(funwarnings, warn$message)
+                }
+                ), silent = TRUE
+            ))
+            serializedResult <- as.character(serialize(list(result, funerror, funwarnings, output), NULL))
+            outputDataFrame <- data.frame(serializedResult, stringsAsFactors=FALSE)
+            ", binArgListCollapse, funName)
+    )
+
+    query <- ""
+    if (!is.null(asuser)){
+        query <- paste0("EXECUTE AS USER = '", asuser, "';")
+    }
+
+    query <- paste0(query
+        ,"\nEXEC sp_execute_external_script"
+        ,"\n@language = N'R'"
+        ,"\n,@script = N'",script, "'"
+        ,"\n, @output_data_1_name = N'outputDataFrame';"
+    )
+
+    if (!is.null(asuser)){
+        query <- paste0(query, "\nREVERT;")
+    }
+
+    success <- 
FALSE
+    error <- ""
+    tryCatch({
+        if (class(connection) == "character"){
+            hodbc <- odbcDriverConnect(connection)
+            if (hodbc == -1){
+                error <- sprintf("Failed to connect to SQL Server using connection string %s", connection)
+                success <- FALSE
+            }
+        } else {
+            hodbc <- connection
+        }
+
+        sqlResult <- sqlQuery(hodbc, query, stringsAsFactors = FALSE)
+        if (is.data.frame(sqlResult)){
+            serializedResult <- sqlResult[[1]]
+            success <- TRUE
+        } else {
+            # an error happened; a character vector contains the error messages
+            error <- paste(sqlResult, collapse = "\n")
+            success <- FALSE
+        }
+    }, error = function(err) {
+        # use <<- so the flags in the enclosing frame are updated,
+        # not handler-local copies
+        success <<- FALSE
+        error <<- err$message
+    }, finally = {
+        if (class(connection) == "character" && hodbc != -1){
+            odbcClose(hodbc)
+        }
+    })
+
+    if (success)
+    {
+        lst <- unserialize(unlist(lapply(lapply(as.character(serializedResult), as.hexmode), as.raw)))
+
+        result <- lst[[1]]
+        funerror <- lst[[2]]
+        funwarnings <- lst[[3]]
+        output <- lst[[4]]
+
+        if (!is.null(output)){
+            for(o in output) {
+                cat(paste0(o,"\n"))
+            }
+        }
+
+        if (!is.null(funwarnings)){
+            for(w in funwarnings) {
+                warning(w, call. = FALSE)
+            }
+        }
+
+        if (!is.null(funerror)){
+            stop(funerror, call. = FALSE)
+        }
+
+        return(result)
+    }
+    else
+    {
+        stop(error, call. = FALSE)
+    }
+}
+
+checkOwner <- function(owner)
+{
+    if (!is.null(owner))
+    {
+        if (is.character(owner) && length(owner) == 1 && nchar(owner) <= MAX_OWNER_SIZE_CONST)
+        {
+            invisible(NULL)
+        }
+        else
+        {
+            stop(paste0("Invalid value for owner: ", owner ,"\n"), call. 
= FALSE)
+        }
+    }
+}
+
+getPackageTopMostAttributeFlag <- function()
+{
+    0x1
+}
+
+pkgTime <- function()
+{
+    # tz: "" = current time zone, "GMT" = UTC
+    return (format(Sys.time(), "%Y-%m-%d %H:%M:%OS2", tz = ""))
+}
+
+checkConnectionString <- function(connectionString)
+{
+    if (!is.null(connectionString) && is.character(connectionString) && length(connectionString) == 1 && nchar(connectionString) > 0)
+    {
+        invisible(NULL)
+    }
+    else
+    {
+        stop(paste0("Invalid connection string: ", connectionString ,"\n"), call. = FALSE)
+    }
+}
+
+checkOdbcHandle <- function(hodbc, connectionString)
+{
+    if (hodbc == -1){
+        stop(sprintf("Failed to connect to SQL Server using connection string %s", connectionString), call. = FALSE)
+    }
+    invisible(NULL)
+}
+
+checkResult <- function(result, expectedResult, errorMessage)
+{
+    if (result != expectedResult){
+        stop(errorMessage, call. = FALSE)
+    }
+    invisible(NULL)
+}
+
+#
+# Removes fields if requested
+#
+processInstalledPackagesResult <- function(result, fields)
+{
+    if (!is.null(fields) && is.character(fields))
+    {
+        result <- result[, fields, drop = FALSE]
+    }
+
+    if ((!is.null(fields)) && (identical(fields, "Package") && is.null(dim(result))))
+    {
+        names(result) <- NULL
+    }
+    return(result)
+}
+
+
+#
+# Returns
+#   the normalized scope string for a character scope input;
+#   stops with an error for anything else
+#
+normalizeScope <- function(scope)
+{
+    scopes <- c("PUBLIC", "PRIVATE", "SYSTEM")
+    if (is.character(scope) && length(scope) == 1)
+    {
+        normScope <- toupper(scope)
+        if (normScope == "SHARED"){
+            normScope <- "PUBLIC"
+        }
+        if (normScope %in% scopes){
+            return (normScope)
+        }
+    }
+
+    stop(sprintf("Invalid scope argument value: %s", scope), call. 
= FALSE) +} + +# +# Parses scope which can be an integer or string +# returns +# 0 for PUBLIC / SHARED +# 1 for PRIVATE +# PUBLIC for 0 +# PRIVATE for 1 +# +parseScope <- function(scope) +{ + + scopes <- c(0L, 1L, 0L) + names(scopes) <- c("PUBLIC", "PRIVATE", "SHARED") + + if ((is.integer(scope) || is.numeric(scope)) && (scope%%1==0)) + { + if ((scope >= 0L) && (scope <= 1L)) + { + scopeIndex <- scope + 1L + parsedScope <- names(scopes)[scopeIndex] + } + else + { + stop("Invalid scope argument value.", call. = FALSE) + } + } + else if (is.character(scope) && length(scope) == 1 && toupper(scope) %in% names(scopes)) + { + parsedScope <- scopes[[toupper(scope)]] + } + else + { + stop("Invalid scope argument value.", call. = FALSE) + } + + if (is.na(parsedScope)) + { + stop("Invalid scope argument value.", call. = FALSE) + } + + parsedScope +} + +# +# Returns TRUE if server is windows +# +sqlIsServerWindows <- function(connectionString) +{ + checkConnectionString(connectionString) + + isWindows <- function() + { + return (Sys.info()['sysname'] == 'Windows'); + } + + isWindowsResult <- sqlRemoteExecuteFun(connectionString, isWindows) + + return(isWindowsResult) +} + +# +# Returns current sql user (the result of SELECT USER query) +# Returns NULL if query failed +# +sqlSelectUser <- function(connectionString) +{ + user <- "" + query <- "SELECT USER;" + + hodbc <- odbcDriverConnect(connectionString) + checkOdbcHandle(hodbc, connectionString) + on.exit(odbcClose(hodbc), add = TRUE) + + sqlResult <- sqlQuery(hodbc, query, stringsAsFactors = FALSE) + + if (is.data.frame(sqlResult)) + { + user <- sqlResult[1,1] + } + else + { + user <- NULL + } + + return (user) +} + +# +# Checks if sql server supports package management based on create external library +# Returns TRUE if server is Windows +# +checkVersion <- function(connectionString) +{ + serverIsWindows <- sqlIsServerWindows(connectionString) + if(!serverIsWindows){ + stop("Package management currently not supported on 
Linux SQL Server.", call. = FALSE)
+    }
+
+    versionClass <- sqlCheckPackageManagementVersion(connectionString)
+
+    if (is.character(versionClass) && versionClass == "ExtLib"){
+        return (serverIsWindows)
+    } else {
+        stop("SQL Server does not support package management.", call. = FALSE)
+    }
+}
+
+
+#
+# Checks if the SQL Server version supports package management
+# Returns "ExtLib" if it supports the external library DDL
+#
+# We support SQL Azure DB
+#
+# SQL Azure          12.0.2000.8
+# SQL Server 2017    14.0.1000.169
+#
+# Note:
+# Older versions of SQL Server are supported by the legacy
+# package management APIs in RevoScaleR:
+#
+# SQL Server 2016 SP1    13.0.4001.0
+# SQL Server 2016        13.0.1601.5
+#
+#' @importFrom utils tail
+sqlCheckPackageManagementVersion <- function(connectionString)
+{
+    versionClass <- NA
+    force(connectionString)
+
+    if(is.null(connectionString) || nchar(connectionString) == 0){
+        stop("Invalid connectionString: null or empty")
+    }
+
+    version <- sqlPackageManagementVersion(connectionString)
+
+    if (is.null(version) || length(version) == 0 || is.na(version[[1]]))
+    {
+        stop("Invalid SQL version: null or empty", call. = FALSE)
+    }
+
+    if( ( (version[[1]]=="azure" && version[[2]] >= 12 ) || (version[[1]]=="box" && version[[2]] >= 15 )))
+    {
+        # server supports external library with DDLs
+        versionClass <- "ExtLib"
+    }
+    else
+    {
+        stop(sprintf("The package management feature is not enabled for the current user or not supported on SQL Server version %s", paste(tail(version, -1), collapse='.')), call. 
= FALSE)
+    }
+
+    return(versionClass)
+}
+
+#
+# Returns a list with "azure" or "box" as the first element and the product version as the remaining elements
+#
+# Examples:
+#   list("azure", 12, 0, 2000, 8)
+#   list("box", 15, 0, 400, 107)
+#
+#' @importFrom utils tail
+sqlPackageManagementVersion <- function(connectionString)
+{
+    force(connectionString)
+
+    pmversion <- NULL
+
+    serverProperties <- sqlServerProperties(connectionString)
+    if (is.null(serverProperties)){
+        stop(sprintf("Failed to get SQL version using connection string '%s'", connectionString ), call. = FALSE)
+    }
+
+    if(serverProperties[[1]] == "sql azure" && serverProperties[[2]]==5)
+    {
+        # sql azure
+        pmversion <- append(list("azure"), tail(serverProperties, -2))
+    }
+    else
+    {
+        # sql box product
+        pmversion <- append(list("box"), tail(serverProperties, -2))
+    }
+
+    return (pmversion)
+}
+
+#
+# Returns a list with the Edition, EngineEdition, and the product version as an integer vector
+# Returns NULL if it failed
+# Strings will be lowercased
+# Examples: list("sql azure", 5, 12, 0, 2000, 8)
+# Examples: list("enterprise edition (64-bit)", 3, 15, 0, 400, 107)
+#
+# References: https://docs.microsoft.com/en-us/sql/t-sql/functions/serverproperty-transact-sql
+#             https://technet.microsoft.com/en-us/library/ms174396(v=sql.110).aspx
+#             http://www.sqlservercentral.com/blogs/gorandalfs-sql-blog/2015/06/10/azure-sql-database-version-and-compatibility-level/
+#
+sqlServerProperties <- function(connectionString)
+{
+    serverProperties <- NULL
+
+    query <- paste0("SELECT CAST(SERVERPROPERTY('Edition') AS nvarchar) AS Edition, CAST(SERVERPROPERTY('EngineEdition') AS nvarchar) AS EngineEdition, CAST(SERVERPROPERTY('ProductVersion') AS nvarchar) AS ProductVersion")
+
+    hodbc <- odbcDriverConnect(connectionString)
+    checkOdbcHandle(hodbc, connectionString)
+    on.exit(odbcClose(hodbc), add = TRUE)
+
+    sqlResult <- sqlQuery(hodbc, query, stringsAsFactors = FALSE)
+
+    if (is.data.frame(sqlResult))
+    {
+        #
+        #   Edition 
EngineEdition ProductVersion + #1 Enterprise Edition (64-bit) 3 15.0.800.91 + # + serverProperties <- list(Edition= tolower(sqlResult$Edition), EngineEdition=as.integer(sqlResult$EngineEdition)) + productVersion <- as.integer(unlist(strsplit(sqlResult$ProductVersion, "\\."))) + serverProperties <- append(serverProperties, productVersion) + } + return (serverProperties) +} + +# +# Returns list containing matrix with installed packages, warnings and errors +# +sqlEnumPackages <- function(connectionString, owner, scope, priority, fields, subarch) +{ + result <- list(packages = NULL, warnings = NULL, errors = NULL) + + scopeint <- parseScope(scope) + + pkgGetLibraryPath <- function(scopeint) + { + if (!all.equal(scopeint,as.integer(scopeint))){ + stop("pkgGetLibraryPathExtLib(): scope expected to be an integer", call. = FALSE) + } + + if (scopeint == 0){ + extLibPath <- Sys.getenv("MRS_EXTLIB_SHARED_PATH") + } else if (scopeint == 1){ + extLibPath <- Sys.getenv("MRS_EXTLIB_USER_PATH") + } else { + stop(paste0("pkgGetLibraryPathExtLib(): invalid scope value ", scopeint, ""), call. 
= FALSE)
+        }
+
+        extLibPath <- normalizePath(extLibPath, mustWork = FALSE)
+        extLibPath <- gsub('\\\\', '/', extLibPath)
+
+        return(extLibPath)
+    }
+
+    #
+    # Returns the PRIVATE, PUBLIC and SYSTEM library paths in a data frame, in this order
+    #
+    sqlGetScopeLibraryPaths <- function(connectionString)
+    {
+        getScopeLibraryPaths <- function()
+        {
+            publicPath <- try(pkgGetLibraryPath(0), silent = TRUE)
+            if (inherits(publicPath, "try-error"))
+            {
+                publicPath <- NA
+            }
+
+            privatePath <- try(pkgGetLibraryPath(1), silent = TRUE)
+            if (inherits(privatePath, "try-error"))
+            {
+                privatePath <- NA
+            }
+
+            systemPath <- .Library
+
+            scopes <- c("PRIVATE", "PUBLIC", "SYSTEM")
+
+            return (data.frame(Scope = scopes, Path = c(privatePath, publicPath, systemPath), row.names = scopes, stringsAsFactors = FALSE))
+        }
+
+        libPaths <- sqlRemoteExecuteFun(connectionString, getScopeLibraryPaths, asuser = owner, includeFun = list(pkgGetLibraryPath = pkgGetLibraryPath))
+
+
+        return(libPaths)
+    }
+
+    #
+    # Appends installed packages for a specific scope & library path
+    #
+    addInstalledPackages <- function(connectionString, installedPackages = NULL, libScope, libPath, priority = NULL, fields = "Package", subarch = NULL)
+    {
+        result <- list(installedPackages = NULL, warnings = NULL, errors = NULL)
+
+        #
+        # Returns a data frame listing all packages and their 'isTopLevel' attribute for the given owner and scope
+        # If the 'isTopLevel' attribute is not set for a package it will be -1
+        #
+        sqlQueryIsTopPackageExtLib <- function(connectionString, packagesNames, owner, scope)
+        {
+            scopeint <- parseScope(scope)
+
+            result <- enumerateTopPackages(
+                connectionString = connectionString,
+                packages = packagesNames,
+                owner = owner,
+                scope = scopeint)
+
+            if (is.null(result) || nrow(result)<1)
+            {
+                return(NULL)
+            }
+            else if (is.data.frame(result))
+            {
+                rownames(result) <- result$name
+                return (result)
+            }
+        }
+
+        # enumerate packages installed under the SQL Server R library path
+        packages <- NULL
+        
tryCatch({ + packages <- sqlRemoteExecuteFun(connectionString, utils::installed.packages, lib.loc = libPath, noCache = TRUE, + priority = priority, fields = NULL, subarch = subarch, + useRemoteFun = TRUE, asuser = owner) + }, + error = function(err){ + stop(paste0("failed to enumerate installed packages on file system: ", err$message), call. = FALSE) + } + ) + + if (!is.null(packages) && nrow(packages)>0) + { + packages <- cbind(packages, Attributes = rep(NA, nrow(packages)), Scope = rep(libScope, nrow(packages))) + + # get top package flag if attributes column will be present in final results and if we are in PUBLIC or PRIVATE scope + if (nrow(packages) > 0 && (libScope == 'PUBLIC' || libScope == 'PRIVATE')) + { + filteredPackages <- processInstalledPackagesResult(packages, fields) + if ('Attributes' %in% colnames(filteredPackages)) + { + packagesNames <- rownames(packages[packages[,'Scope'] == libScope,, drop = FALSE]) + + if (length(packagesNames) > 0) + { + isTopPackageDf<-sqlQueryIsTopPackageExtLib(connectionString, packagesNames, owner, libScope) + + if (!is.null(isTopPackageDf)) + { + for(pkg in packagesNames) + { + if (packages[pkg,'Scope'] == libScope) + { + isTopPackage <- as.integer(isTopPackageDf[pkg,'IsTopPackage']) + if (isTopPackage == -1){ + isTopPackage = 1 + result$warnings <- c(result$warnings, sprintf("missing attribute for package set as top level: (%s)", pkg)) + } + packages[pkg,'Attributes'] <- isTopPackage + } + } + } + } + } + } + + + if (is.null(installedPackages)) + { + installedPackages <- packages + } + else + { + installedPackages <- rbind(installedPackages, packages) + } + } + + result$installedPackages <- installedPackages + + return(result) + } + + extLibPaths <- sqlGetScopeLibraryPaths(connectionString) + + installedPackages <- NULL + for(i in 1:nrow(extLibPaths)) + { + libPath <- extLibPaths[i, "Path"] + + if (!is.na(libPath)) + { + libScope <- extLibPaths[i, "Scope"] + + ret <- NULL + if (libScope == "PRIVATE") + { + if (scope == 
"PRIVATE") + { + ret <- addInstalledPackages(connectionString, installedPackages, libScope, libPath, priority, fields, subarch) + } + } + else + { + ret <- addInstalledPackages(connectionString, installedPackages, libScope, libPath, priority, fields, subarch) + } + if (!is.null(ret)){ + installedPackages <- ret$installedPackages + result$warnings <- c(result$warnings,ret$warnings) + result$errors <- c(result$errors,ret$errors) + } + } + } + + installedPackages <- processInstalledPackagesResult(installedPackages, fields) + + result$packages <- installedPackages + + return(result) +} + +getDependentPackagesToInstall <- function(pkgs, availablePackages, installedPackages, skipMissing = FALSE, + verbose = getOption("verbose")) +{ + # + # prune requested packages to exclude base packages + # + basePackages <- installedPackages[installedPackages[,"Priority"] %in% c("base", "recommended"), c("Package", "Priority"), drop = FALSE]$Package + droppedPackages <- pkgs[pkgs %in% basePackages] + + if (length(droppedPackages) > 0) + { + warning(sprintf("Skipping base packages (%s)", paste(droppedPackages, collapse = ', '))) + } + + pkgs <- pkgs[!(pkgs %in% droppedPackages)] + + if (length(pkgs) < 1) + { + return (NULL) + } + + # + # get dependency closure for all given packages + # note: by default we obtain a package+dependencies from one CRAN which should have versions that work together without conflicts. 
+ # + if (verbose) + { + write(sprintf("%s Resolving package dependencies for (%s)...", pkgTime(), paste(pkgs, collapse = ', ')), stdout()) + } + + dependencies <- tools::package_dependencies(packages = pkgs, db = availablePackages, recursive = TRUE, verbose = FALSE) + + # + # get combined dependency closure w/o base packages + # + dependencies <- unique(unlist(c(dependencies, names(dependencies)), recursive = FALSE, use.names = FALSE)) + dependencies <- dependencies[dependencies != "NA"] + dependencies <- dependencies[!(dependencies %in% basePackages)] + + if (length(dependencies) < 1) + { + return (NULL) + } + + # + # are there any missing packages in dependency closure? + # + availablePackageNames <- rownames(availablePackages) + missingPackages <- dependencies[!(dependencies %in% availablePackageNames)] + + if (length(missingPackages) > 0) + { + missingPackagesStr <- sprintf("Missing dependency packages (%s)", paste(missingPackages, collapse = ', ')) + + if (!skipMissing) + { + stop(missingPackagesStr, call. = FALSE) + } + else + { + warning(missingPackagesStr) + } + } + + # + # get the packages in order of dependency closure + # + dependencies <- unique(dependencies) + pkgsToInstall <- availablePackages[match(dependencies, availablePackageNames),] + pkgsToInstall <- pkgsToInstall[!is.na(pkgsToInstall$Package),] + + return (pkgsToInstall) +} + +# +# Returns list with 2 data frames. 
+# The first data frame contains the pruned packages to install
+# The second data frame contains the pruned packages to mark as top-level
+#
+prunePackagesToInstallExtLib <- function(dependentPackages, topMostPackages, installedPackages, verbose = getOption("verbose"))
+{
+    prunedPackagesToInstall <- NULL
+    prunedPackagesToTop <- NULL
+
+    for (pkgToInstallIndex in 1:nrow(dependentPackages))
+    {
+        pkgToInstall <- dependentPackages[pkgToInstallIndex,]
+
+        # get available packages that match the name of the package we depend on
+        availablePkgs <- installedPackages[match(pkgToInstall$Package, installedPackages$Package, nomatch = 0),, drop = FALSE]
+
+
+        if (nrow(availablePkgs) == 0)
+        {
+            # no packages available, add the package we depend on to the list of pruned packages to install
+            prunedPackagesToInstall <- rbind(prunedPackagesToInstall, pkgToInstall)
+        }
+        else
+        {
+            # If a package A being installed depends on B and B is already installed, 3 scenarios are possible:
+            # (1) versions are the same -> OK
+            # (2) installed version is newer -> OK
+            # (3) installed version is older -> we print a warning to allow the user to make a proper decision
+            for(scope in c("PRIVATE", "PUBLIC", "SYSTEM"))
+            {
+                availablePkg <- availablePkgs[ availablePkgs$Scope == scope,, drop = FALSE ]
+                if (nrow(availablePkg) == 1){
+                    if (utils::compareVersion(availablePkg$Version, pkgToInstall$Version) == -1){
+                        # pkgToInstall is newer (later) than availablePkg
+                        warning(sprintf("package is already installed but version is older than available in repos: package='%s', scope='%s', currently installed version='%s', new version='%s'", pkgToInstall$Package, scope, availablePkg$Version, pkgToInstall$Version), call. 
= FALSE)
+                    }
+                    break
+                }
+            }
+
+            # if the available package is being requested as a top-level package, we check whether
+            # the top-level attribute on the package is set to false; if so, we have to update it to true
+            if ('Attributes' %in% colnames(installedPackages)){
+                if (pkgToInstall$Package %in% topMostPackages){ # package to install is requested as top-level
+                    # if the package is marked as a dependency we have to set it as top-level
+                    pkgToTop <- availablePkgs[!is.na(availablePkgs[,'Attributes']) &
+                                              bitwAnd(as.integer(availablePkgs[,'Attributes']), getPackageTopMostAttributeFlag()) == 0
+                                              ,, drop = FALSE]
+                    if (nrow(pkgToTop) > 0)
+                    {
+                        prunedPackagesToTop <- rbind(prunedPackagesToTop, pkgToTop)
+                    }
+                }
+            }
+        }
+    }
+
+    return (list(prunedPackagesToInstall, prunedPackagesToTop))
+}
+
+downloadDependentPackages <- function(pkgs, destdir, binaryPackages, sourcePackages,
+                                      verbose = getOption("verbose"), binaryType = "win.binary")
+{
+    downloadedPkgs <- NULL
+    numPkgs <- nrow(pkgs)
+
+    for (pkgIndex in 1:numPkgs)
+    {
+        pkg <- pkgs[pkgIndex,]
+
+        if (verbose)
+        {
+            write(sprintf("%s Downloading package [%d/%d] %s (%s)...", pkgTime(), pkgIndex, numPkgs, pkg$Package, pkg$Version), stdout())
+        }
+
+        #
+        # try the binary package first
+        #
+        downloadedPkg <- utils::download.packages(pkg$Package, destdir = destdir,
+                                                  available = binaryPackages, type = binaryType, quiet = TRUE)
+
+        if (length(downloadedPkg) < 1)
+        {
+            #
+            # try the source package if the binary package isn't there
+            #
+            downloadedPkg <- utils::download.packages(pkg$Package, destdir = destdir,
+                                                      available = sourcePackages, type = "source", quiet = TRUE)
+        }
+
+        if (length(downloadedPkg) < 1)
+        {
+            stop(sprintf("Failed to download package %s.", pkg$Package), call. 
= FALSE) + } + + downloadedPkg[1,2] <- normalizePath(downloadedPkg[1,2], mustWork = FALSE) + downloadedPkgs <- rbind(downloadedPkgs, downloadedPkg) + } + + downloadedPkgs <- data.frame(downloadedPkgs, stringsAsFactors = FALSE) + colnames(downloadedPkgs) <- c("Package", "File") + rownames(downloadedPkgs) <- downloadedPkgs$Package + + return (downloadedPkgs) +} + + +# +# Installs packages using external library ddl support +# +sqlInstallPackagesExtLib <- function(connectionString, + pkgs, + skipMissing = FALSE, repos = getOption("repos"), verbose = getOption("verbose"), + scope = "private", owner = '', + serverIsWindows = TRUE) +{ + # + # check permissions + # + checkPermission <- function(connectionString, scope, owner, verbose) + { + sqlCheckPermission <- function(connectionString, scope, owner) + { + allowed <- FALSE + + haveOwner <- (nchar(owner) > 0) + query <- "" + + if (haveOwner){ + query <- paste0("EXECUTE AS USER = '", owner , "';\n") + } + + query <- paste0(query, "SELECT USER;") + + if (haveOwner) { + query <- paste0(query, "\nREVERT;") + } + + hodbc <- odbcDriverConnect(connectionString) + checkOdbcHandle(hodbc, connectionString) + on.exit(odbcClose(hodbc), add = TRUE) + + sqlResult <- sqlQuery(hodbc, query, stringsAsFactors = FALSE) + + + if (is.data.frame(sqlResult)) + { + user <- sqlResult[1,1] + + if (user == '') + { + allowed <- FALSE + } + else if (scope == 1 && user == "dbo") + { + # block dbo call to install into PRIVATE lib path which is not supported by create external library + allowed <- FALSE + } + else + { + allowed <- TRUE + } + } + else + { + #cannot execute as the database principal because the principal "uga" does not exist + allowed <- FALSE + } + + return (allowed) + } + + if (verbose) + { + write(sprintf("%s Verifying permissions to install packages on SQL server...", pkgTime()), stdout()) + } + + if (scope == "PUBLIC") + { + if (is.character(owner) && nchar(owner) >0) + { + stop(paste0("Invalid use of scope PUBLIC. 
Use scope 'PRIVATE' to install packages for owner '", owner ,"'\n"), call. = FALSE) + } + } + else if (scope == "PRIVATE") + { + # fail dbo calls to install to private scope as dbo can only install to public + scopeint <- parseScope(scope) + allowed <- sqlCheckPermission(connectionString, scope, owner) + + if (!allowed) + { + stop(sprintf("Permission denied for installing packages on SQL server for current user: scope='%s', owner='%s'.", scope, owner), call. = FALSE) + } + } + } + + attributePackages <- function(connectionString, packages, scopeint, owner, verbose) + { + packagesNames <- sapply(packages, function(pkg){pkg$name},USE.NAMES = FALSE) + + if (verbose) + { + write(sprintf("%s Attributing packages on SQL server (%s)...", pkgTime(), paste(packagesNames, collapse = ', ')), stdout()) + } + + result <- sqlMakeTopLevel(connectionString = connectionString, + packages = packagesNames, + owner = owner, + scope = as.integer(scopeint)) + + if (result) { + write(sprintf("Successfully attributed packages on SQL server (%s).", + paste(packagesNames, collapse = ', ')), stdout()) + } + } + + # check scope and permission to write to scoped folder + scope <- normalizeScope(scope) + scopeint <- parseScope(scope) + + checkPermission(connectionString, scope, owner, verbose) + + topMostPackageFlag <- getPackageTopMostAttributeFlag() + + if (length(pkgs) > 0) + { + downloadDir <- tempfile("download") + dir.create(downloadDir) + on.exit(unlink(downloadDir, recursive = TRUE), add = TRUE) + + packages <- list() + + if (length(repos) > 0) + { + # + # get the available package lists + # + sourcePackages <- utils::available.packages(utils::contrib.url(repos = repos, type = "source"), type = "source") + row.names(sourcePackages) <- NULL + binaryPackages <- if (serverIsWindows) utils::available.packages(utils::contrib.url(repos = repos, type = "win.binary"), type = "win.binary") else NULL + row.names(binaryPackages) <- NULL + pkgsUnison <- data.frame(rbind(sourcePackages, 
binaryPackages), stringsAsFactors = FALSE) + pkgsUnison <- subset(pkgsUnison, !duplicated(Package)) + row.names(pkgsUnison) <- pkgsUnison$Package + + # + # check for missing packages + # + missingPkgs <- pkgs[!(pkgs %in% pkgsUnison$Package) ] + + if (length(missingPkgs) > 0) + { + stop(sprintf("Cannot find specified packages (%s) to install", paste(missingPkgs, collapse = ', ')), call. = FALSE) + } + + # + # get all installed packages + # + installedPackages <- sql_installed.packages(connectionString, fields = NULL, scope = scope, owner = owner) + installedPackages <- data.frame(installedPackages, row.names = NULL, stringsAsFactors = FALSE) + + # + # get dependency closure of given packages + # + pkgsToDownload <- getDependentPackagesToInstall(pkgs = pkgs, availablePackages = pkgsUnison, + installedPackages = installedPackages, + skipMissing = skipMissing, verbose = verbose) + + # + # prune dependencies for already installed packages + # + prunedPkgs <- prunePackagesToInstallExtLib(dependentPackages = pkgsToDownload, + topMostPackages = pkgs, + installedPackages = installedPackages, verbose = verbose) + pkgsToDownload <- prunedPkgs[[1]] + pkgsToAttribute <- prunedPkgs[[2]] + + + if (length(pkgsToDownload) < 1 && length(pkgsToAttribute) < 1) + { + write(sprintf("Packages (%s) are already installed.", paste(pkgs, collapse = ', ')), stdout()) + + return (invisible(NULL)) + } + + if (length(pkgsToDownload) > 0) + { + # + # download all the packages in dependency closure + # + downloadPkgs <- downloadDependentPackages(pkgs = pkgsToDownload, destdir = downloadDir, + binaryPackages = binaryPackages, sourcePackages = sourcePackages, verbose = verbose) + } + + if (length(pkgsToDownload) > 0) + { + attributesVec<-apply(downloadPkgs, 1, function(x){ + packageAttributes <- 0x0 + if (x["Package"] %in% pkgs){ + packageAttributes <- bitwOr(packageAttributes,topMostPackageFlag) + } + return (packageAttributes) + } + ) + + downloadPkgs <- cbind(downloadPkgs, Attribute = 
attributesVec) + sqlHelperInstallPackages(connectionString, downloadPkgs, owner, scope, verbose) + + } + + if (length(pkgsToAttribute) > 0) + { + for (packageIndex in 1:nrow(pkgsToAttribute)) + { + packageDescriptor <- list() + packageDescriptor$name <- pkgsToAttribute[packageIndex,"Package"] + packageAttributes <- 0x0 + if (packageDescriptor$name %in% pkgs){ + packageAttributes <- bitwOr(packageAttributes,topMostPackageFlag) + } + packageDescriptor$attributes <- packageAttributes + + + packages[[length(packages) + 1]] <- packageDescriptor + } + + attributePackages(connectionString, packages, scopeint, owner, verbose) + } + + } + else + { + # no repos provided, packages are file paths + pkgs <- normalizePath(pkgs, mustWork = FALSE) + missingPkgs <- pkgs[!file.exists(pkgs)] + + if (length(missingPkgs) > 0) + { + stop(sprintf("%s packages are missing.", paste0(missingPkgs, collapse = ", ")), call. = FALSE) + } + + packages <- data.frame(matrix(nrow = 0, ncol = 3), stringsAsFactors = FALSE) + for( packageFile in pkgs){ + packages <- rbind(packages, data.frame( + Package = unlist(lapply(strsplit(basename(packageFile), '\\.|_'), '[[', 1), use.names = F), + File = packageFile, + Attribute = topMostPackageFlag, + stringsAsFactors = FALSE)) + } + + sqlHelperInstallPackages(connectionString, packages, owner, scope, verbose) + } + + } +} + +# +# Calls CREATE EXTERNAL LIBRARY on a package +# +sqlCreateExternalLibrary <- function(hodbc, packageName, packageFile, user = "") +{ + # read zip file into binary format + fileConnection <- file(packageFile, 'rb') + pkgBin <- readBin(con = fileConnection, what = raw(), n = file.size(packageFile)) + close(fileConnection) + pkgContent = paste0("0x", paste0(pkgBin, collapse = "")); + + + haveUser <- (user != '') + + query <- paste0("CREATE EXTERNAL LIBRARY [", packageName, "]") + + if (haveUser){ + query <- paste0(query, " AUTHORIZATION ", user) + } + + query <- paste0(query, " FROM (CONTENT=", pkgContent ,") WITH (LANGUAGE = 'R');") + + 
sqlResult <- sqlQuery(hodbc, query, stringsAsFactors = FALSE)
+
+    if (is.character(sqlResult)){
+        return (TRUE)
+    }
+
+    stop(paste(sqlResult, collapse = "\n"))
+}
+
+#
+# Calls DROP EXTERNAL LIBRARY on a package
+#
+sqlDropExternalLibrary <- function(hodbc, packageName, user = "")
+{
+    haveUser <- (user != '')
+
+    query <- paste0("DROP EXTERNAL LIBRARY [", packageName, "]")
+
+    if (haveUser){
+        query <- paste0(query, " AUTHORIZATION ", user, ";")
+    }
+
+    sqlResult <- sqlQuery(hodbc, query, stringsAsFactors = FALSE)
+
+    if (is.character(sqlResult)){
+        return (TRUE)
+    }
+
+    stop(paste(sqlResult, collapse = "\n"))
+}
+
+#
+# Adds an extended property to a package to store attributes (top-level package)
+#
+sqlAddExtendedProperty <- function(hodbc, packageName, attributes, user = "")
+{
+    # extract the top-level flag as 0/1 so it interpolates as a valid sql_variant value
+    isTopLevel <- bitwAnd(as.integer(attributes), 0x1)
+
+    haveUser <- (user != '')
+
+    # use an extended property to mark top-level packages
+    if (haveUser){
+        # if we have a user, bind it into the query
+        query <- paste0("EXEC sp_addextendedproperty @name = N'IsTopPackage', @value=", isTopLevel, ", @level0type=N'USER', @level0name=", user, ", @level1type = N'external library', @level1name =", packageName)
+    } else {
+        # if user is empty we use the current user
+        query <- paste0("DECLARE @currentUser NVARCHAR(128); SELECT @currentUser = CURRENT_USER; EXEC sp_addextendedproperty @name = N'IsTopPackage', @value=", isTopLevel, ", @level0type=N'USER', @level0name=@currentUser, @level1type = N'external library', @level1name =", packageName)
+    }
+
+    sqlResult <- sqlQuery(hodbc, query, stringsAsFactors = FALSE)
+
+    if (is.character(sqlResult)){
+        return (TRUE)
+    }
+
+    # an error happened; the vector of strings contains the error messages
+    stop(paste(sqlResult, collapse = "\n"))
+}
+
+sqlMakeTopLevel <- function(connectionString, packages, owner, scope)
+{
+    changeTo <- 1
+    haveUser <- (owner != '')
+
+    if (haveUser) {
+        user <- "?" 
+        query <- ""
+    } else {
+        user <- "@currentUser"
+        query <- "DECLARE @currentUser NVARCHAR(128);
+            SELECT @currentUser = CURRENT_USER;"
+    }
+    query <- paste0(query, "EXEC sp_updateextendedproperty @name = N'IsTopPackage', @value=", changeTo, ", @level0type=N'USER',
+        @level0name=", user, ", @level1type = N'external library', @level1name=?")
+
+    packageList <- enumerateTopPackages(connectionString, packages, owner, scope)$name
+
+    # initialize the handle so the finally block is safe even if the connect call fails
+    hodbc <- -1
+    tryCatch({
+        hodbc <- odbcDriverConnect(connectionString)
+        checkOdbcHandle(hodbc, connectionString)
+
+        for(pkg in intersect(packages, packageList)) {
+            if (haveUser) {
+                result <- sqlExecute(hodbc, query = query,
+                                     owner, pkg,
+                                     fetch = TRUE)
+            } else {
+                result <- sqlExecute(hodbc, query = query,
+                                     pkg,
+                                     fetch = TRUE)
+            }
+        }
+    }, error = function(err) {
+        stop(sprintf("Attribution of packages %s failed with error %s",
+                     paste(packages, collapse = ', '), err$message), call. = FALSE)
+    }, finally = {
+        if (hodbc != -1){
+            odbcClose(hodbc)
+        }
+    })
+    return(TRUE)
+}
+
+#
+# Syncs packages to the file system by calling sp_execute_external_script
+# and checks whether the packages are installed on the library path.
+# The library path is determined by the scope and the user.
+#
+sqlSyncAndCheckInstalledPackages <- function(hodbc, packages, user = "", scope = "PRIVATE")
+{
+    intscope <- parseScope(scope)
+
+    checkPackages <- function(packages, intscope)
+    {
+        resultdf <- data.frame(Package = NA, Found = NA, stringsAsFactors = FALSE)
+
+        if (is.null(packages)) {
+            stop('ERROR: input package list is empty')
+        }
+
+        lib <- NULL
+        if (intscope == 0) {
+            lib <- Sys.getenv('MRS_EXTLIB_SHARED_PATH')
+        } else if (intscope == 1) {
+            lib <- Sys.getenv('MRS_EXTLIB_USER_PATH')
+        } else {
+            stop(paste0('ERROR: invalid scope=', intscope))
+        }
+
+        resultdf <- data.frame(Package = packages, Found = rep(FALSE, length(packages)), row.names = packages, stringsAsFactors = FALSE)
+        packagesFound <- 
find.package(packages, lib.loc = lib, quiet = TRUE) + packagesNames <- unlist(lapply(packagesFound, basename)) + + if (!is.null(packagesNames)){ + resultdf[packagesNames, 'Found'] <-TRUE + } + } + + return (resultdf) + } + + # sp_execute_external_script will first install packages to the file system + # and the run R function to check if packages installed + checkdf <- sqlRemoteExecuteFun(hodbc, checkPackages, packages, intscope, asuser = user) + + # issue warnings for packages not found + apply(checkdf, 1, function(x){ + packageName <- x[[1]] + found <- x[[2]] + if (found == FALSE){ + stop(sprintf("package failed to install to file system: package='%s', user='%s', scope='%s'", packageName, user, scope), call. = FALSE) + } + } + ) + + return(invisible(NULL)) +} + +sqlHelperInstallPackages <- function(connectionString, packages, owner = "", scope = "PRIVATE", verbose) +{ + user <- "" + + scopeint <- parseScope(scope) + + if (scopeint == 0 && owner == '') + { + # if scope is public the user has to be either dbo or member of db_owner + # if current user is already dbo we just proceed, else if user + # is member of db_owner (e.g. RevoTester) we run as 'dbo' to + # force it to install into the public folder instead of the private. 
+ currentUser <- sqlSelectUser(connectionString); + if (currentUser == "dbo") + { + user <- ""; + } + else + { + user <- "dbo"; + } + } + else + { + user <- owner; + } + + hodbc <- -1 + haveTransaction <- FALSE + tryCatch({ + hodbc <- odbcDriverConnect(connectionString) + checkOdbcHandle(hodbc, connectionString) + checkResult( odbcSetAutoCommit(hodbc, autoCommit = FALSE), 0, "failed to create transaction") + haveTransaction <- TRUE + + for (packageIndex in 1:nrow(packages)) + { + packageName <- packages[packageIndex,"Package"] + filelocation <- packages[packageIndex, "File"] + attribute <- packages[packageIndex, "Attribute"] + + sqlCreateExternalLibrary(hodbc, packageName, filelocation, user) + sqlAddExtendedProperty(hodbc, packageName, attribute, user) + } + + sqlSyncAndCheckInstalledPackages(hodbc, packages[,"Package"], user, scope); + odbcEndTran(hodbc, commit = TRUE) + } + , error = function(err) { + stop( sprintf("Installation of packages %s failed with error %s", paste(packages[,"Package"], collapse = ', '), err$message), call. = FALSE) + } + , finally = { + if(haveTransaction){ + # rollback / close open transactions otherwise odbcClose() will fail + odbcEndTran(hodbc, commit = FALSE) + } + if(hodbc != -1){ + odbcClose(hodbc) + } + } + ) + + write(sprintf("Successfully installed packages on SQL server (%s).", + paste(packages[,"Package"], collapse = ', ')), stdout()) +} + + +sqlHelperRemovePackages <- function(connectionString, pkgs, pkgsToDrop, scope, owner) +{ + user <- "" + + scopeint <- parseScope(scope) + + if (scopeint == 0 && owner == '') + { + # if scope is public the user has to be either dbo or member of db_owner + # if current user is already dbo we just proceed, else if user + # is member of db_owner (e.g. RevoTester) we run as 'dbo' to + # force it to install into the public folder instead of the private. 
+        currentUser <- sqlSelectUser(connectionString)
+        if (currentUser == "dbo")
+        {
+            user <- ""
+        }
+        else
+        {
+            user <- "dbo"
+        }
+    }
+    else
+    {
+        user <- owner
+    }
+
+    hodbc <- -1
+    haveTransaction <- FALSE
+    tryCatch({
+        hodbc <- odbcDriverConnect(connectionString)
+        checkOdbcHandle(hodbc, connectionString)
+
+        checkResult( odbcSetAutoCommit(hodbc, autoCommit = FALSE), 0, "failed to create transaction")
+        haveTransaction <- TRUE
+
+        # first drop potentially bad packages that failed to install during SPEES,
+        # then uninstall fully installed packages, combining DROP + SPEES
+        if (length(pkgsToDrop) > 0)
+        {
+            sqlDropPackages(hodbc, pkgsToDrop, user)
+        }
+
+        if (length(pkgs) > 0)
+        {
+            sqlDropPackages(hodbc, pkgs, user)
+            sqlSyncRemovePackage(hodbc, user)
+        }
+
+        odbcEndTran(hodbc, commit = TRUE)
+    }, error = function(err) {
+        stop(sprintf("Removal of packages %s failed with error %s", paste(c(pkgs, pkgsToDrop), collapse = ', '), err$message), call. = FALSE)
+    }, finally = {
+        if(haveTransaction){
+            # rollback / close open transactions, otherwise odbcClose() will fail
+            odbcEndTran(hodbc, commit = FALSE)
+        }
+        if(hodbc != -1){
+            odbcClose(hodbc)
+        }
+    })
+
+    write(sprintf("Successfully removed packages from SQL server (%s).", paste(c(pkgs, pkgsToDrop), collapse = ', ')), stdout())
+}
+
+#
+# Calls sp_execute_external_script to remove, from the file system, packages
+# that we previously dropped
+#
+sqlSyncRemovePackage <- function(hodbc, user)
+{
+    noop <- function()
+    {
+        return(invisible(NULL))
+    }
+
+    sqlRemoteExecuteFun(hodbc, noop, asuser = user)
+
+    return(invisible(NULL))
+}
+
+sqlDropPackages <- function(hodbc, packages, user)
+{
+    for (package in packages)
+    {
+        sqlDropExternalLibrary(hodbc, package, user)
+    }
+}
+#
+# Returns a data frame with the list of packages found in sys.external_libraries.
+# Columns in the data frame are [Package][Found].
+# All submitted packages will be listed. 
+# If a package was found in the database, the Found value will be TRUE, otherwise FALSE
+#
+sqlEnumTable <- function(connectionString, packagesNames, owner, scopeint)
+{
+    queryUser <- "CURRENT_USER"
+
+    if (scopeint == 0) # public
+    {
+        currentUser <- sqlSelectUser(connectionString)
+        if (currentUser == "dbo")
+        {
+            queryUser <- "CURRENT_USER"
+        }
+        else
+        {
+            queryUser <- "'dbo'"
+        }
+    }
+    else if (nchar(owner) > 0)
+    {
+        queryUser <- paste0("'", owner, "'")
+    }
+
+    query <- paste0(
+        " DECLARE @currentUser NVARCHAR(128);",
+        " DECLARE @principalId INT;"
+    )
+
+    query <- paste0(query, " SELECT @currentUser = ", queryUser, ";")
+
+    query <- paste0(query,
+        " SELECT @principalId = USER_ID(@currentUser);",
+        " SELECT elib.name",
+        " FROM sys.external_libraries AS elib",
+        " WHERE elib.name in (",
+        paste0("'", paste(packagesNames, collapse = "','"), "'"),
+        ")",
+        " AND elib.principal_id=@principalId",
+        " AND elib.language='R' AND elib.scope=", scopeint,
+        " ORDER BY elib.name ASC",
+        " ;"
+    )
+
+    hodbc <- odbcDriverConnect(connectionString)
+    checkOdbcHandle(hodbc, connectionString)
+    on.exit(odbcClose(hodbc), add = TRUE)
+
+    sqlResult <- sqlQuery(hodbc, query, stringsAsFactors = FALSE)
+
+    resultdf <- data.frame(Package = packagesNames, Found = rep(FALSE, length(packagesNames)), row.names = packagesNames, stringsAsFactors = FALSE)
+
+    if (is.data.frame(sqlResult))
+    {
+        resultdf[sqlResult[,"name"],"Found"] <- TRUE
+    }
+
+    return(resultdf)
+}
+
+getDependentPackagesToUninstall <- function(pkgs, installedPackages, dependencies = TRUE, checkReferences = TRUE, verbose = getOption("verbose"))
+{
+    excludeTopMostPackagesDependencies <- function(pkgsToRemove, dependencies, db, basePackages, verbose)
+    {
+        # This function removes, from the given packages' dependency lists, all the packages
+        # which are top most (and their dependencies) and which are not explicitly
+        # stated to be removed
+
+        prunedDependencies <- dependencies
+
+        # If we have the topmost package information, remove, from the dependencies, the packages which are marked as topmost
+        if ('Attributes' %in% colnames(db)){
+
+            # Find all the packages, in the installed packages database, which are explicitly marked as top most
+            topMostInstalledPackages <- db[!is.na(db[,'Attributes']) &
+                                           bitwAnd(as.integer(db[,'Attributes']), getPackageTopMostAttributeFlag()) == 1
+                                           ,, drop = FALSE]
+
+            topMostDependencies <- unique(unlist(dependencies, recursive = TRUE, use.names = FALSE))
+            topMostDependencies <- topMostDependencies[topMostDependencies %in% topMostInstalledPackages[,"Package"]]
+
+            if (length(topMostDependencies) != 0){
+                # Exclude, from the top most dependencies, the packages which we specifically asked to remove
+                topMostDependencies <- topMostDependencies[!(topMostDependencies %in% pkgsToRemove)]
+            }
+
+            if (length(topMostDependencies) != 0){
+                # Get the top most packages' dependencies to ensure they can still work
+                topMostDependencies <- unique(c(unlist(tools::package_dependencies(packages = topMostDependencies,
+                                                                                   db = db, recursive = TRUE,
+                                                                                   verbose = FALSE), recursive = TRUE, use.names = FALSE),
+                                                topMostDependencies))
+
+                # Remove the dependencies which are base packages; those are handled separately
+                topMostDependencies <- topMostDependencies[!topMostDependencies %in% basePackages]
+            }
+
+            if (length(topMostDependencies) != 0){
+                skippedDependencies <- character(0)
+                prunedDependencies <- lapply(X = dependencies,
+                                             FUN = function(dependency){
+                                                 skippedDependencies <<- c(skippedDependencies, dependency[dependency %in% topMostDependencies])
+                                                 dependency[!dependency %in% topMostDependencies]
+                                             })
+
+                if (verbose && length(skippedDependencies) > 0){
+                    write(sprintf("%s skipping the following top-level dependent packages (%s)...", pkgTime(), paste(unique(skippedDependencies), collapse = ', ')), stdout())
+                }
+            }
+        }
+
+        prunedDependencies
+    }
+
+    #
+    # prune requested packages to exclude base packages
+    #
+    basePackages <- installedPackages[installedPackages[,"Priority"] %in% 
c("base", "recommended"), c("Package", "Priority"), drop = FALSE]$Package + + droppedPackages <- pkgs[pkgs %in% basePackages] + + if (length(droppedPackages) > 0) + { + warning(sprintf("Skipping base packages (%s)", paste(droppedPackages, collapse = ', '))) + } + + pkgs <- pkgs[!(pkgs %in% droppedPackages)] + + if (length(pkgs) < 1) + { + return (NULL) + } + + if (dependencies == FALSE) + { + dependencies = pkgs + } + else + { + # + # get dependency closure for all given packages + # + if (verbose) + { + write(sprintf("%s Resolving package dependencies for (%s)...", pkgTime(), paste(pkgs, collapse = ', ')), stdout()) + } + + dependencies <- tools::package_dependencies(packages = pkgs, db = installedPackages, recursive = TRUE, verbose = FALSE) + + # Exclude, from the package dependencies, all the packages which are marked as top most and their dependencies + dependencies <- excludeTopMostPackagesDependencies(pkgsToRemove = pkgs, + dependencies = dependencies, + db = installedPackages, + basePackages = basePackages, + verbose = verbose) + + dependencies <- c(dependencies, pkgs) + + # + # get combined dependency closure w/o base packages + # + dependencies <- unique(unlist(c(dependencies, names(dependencies)), recursive = TRUE, use.names = FALSE)) + dependencies <- dependencies[dependencies != "NA" & dependencies != ""] + dependencies <- dependencies[!(dependencies %in% basePackages)] + + if (length(dependencies) < 1) + { + return (NULL) + } + } + + if (checkReferences == TRUE) + { + # + # get reverse dependency closure for all given packages + # + if (verbose) + { + write(sprintf("%s Resolving package reverse dependencies for (%s)...", pkgTime(), paste(pkgs, collapse = ', ')), stdout()) + } + + pkgsToSkip <- list() + + for (dependency in dependencies) + { + rdependencies <- tools::package_dependencies(packages = dependency, db = installedPackages, reverse = TRUE, recursive = TRUE, verbose = FALSE) + rdependencies <- unique(unlist(c(rdependencies, 
names(rdependencies)), recursive = TRUE, use.names = FALSE)) + rdependencies <- rdependencies[rdependencies != "NA"] + rdependencies <- rdependencies[rdependencies != ""] + rdependencies <- rdependencies[!(rdependencies %in% dependencies)] + + if (length(rdependencies) > 0) + { + if (dependency %in% pkgs) + { + skipMessage <- sprintf("skipping package (%s) being used by packages (%s)...", + dependency, paste(rdependencies, collapse = ', ')) + warning(skipMessage) + } + else + { + skipMessage <- sprintf("skipping dependent package (%s) being used by packages (%s)...", + dependency, paste(rdependencies, collapse = ', ')) + write(skipMessage, stdout()) + } + + pkgsToSkip <- c(pkgsToSkip, dependency) + } + } + + pkgsToSkip <- unique(unlist(pkgsToSkip, recursive = TRUE, use.names = FALSE)) + + # + # remove packages which are being referenced by other packages + # + dependencies <- dependencies[!(dependencies %in% pkgsToSkip)] + + if (length(dependencies) < 1) + { + return (NULL) + } + } + + # + # get the packages in order of dependency closure + # + dependencies <- unique(dependencies) + pkgsToRemove <- installedPackages[match(dependencies, installedPackages$Package),, drop = FALSE] + pkgsToRemove <- pkgsToRemove[!is.na(pkgsToRemove$Package),] + + return (pkgsToRemove) +} + +enumerateTopPackages <- function(connectionString, packages, owner, scope) +{ + haveUser <- (owner != '') + + query <- "DECLARE @principalId INT; + DECLARE @currentUser NVARCHAR(128);" + + query <- paste0( query, paste(sapply( seq(1,length(packages)), function(i){paste0("DECLARE @pkg", toString(i), " NVARCHAR(MAX);")} ), collapse=" ")) + + query = paste0( query, "SELECT @currentUser = ") + + if (haveUser) { + query<-paste0(query, "?") + data <- data.frame(name = owner, stringsAsFactors = FALSE) + } else { + query = paste0(query, "CURRENT_USER;") + data <- data.frame(matrix(nrow=1, ncol=0), stringsAsFactors = FALSE) + } + + for(pkg in packages) + { + data <- cbind(data, pkg, stringsAsFactors = FALSE) 
+ } + data <- cbind(data, scope = scope, stringsAsFactors = FALSE) + + + query <- paste0( query, paste(sapply( seq(1,length(packages)), function(i){paste0("SELECT @pkg", toString(i), " = ?;")} ), collapse=" ")) + pkgcsv <- paste(sapply( seq(1,length(packages)), function(i){paste0("@pkg", toString(i))} ), collapse=",") + + query = paste0(query , sprintf(" + SELECT @principalId = USER_ID(@currentUser); + WITH eprop + AS ( + SELECT piv.major_id, CAST([IsTopPackage] as bit) AS IsTopPackage FROM sys.extended_properties + PIVOT (min(value) FOR name IN ([IsTopPackage])) AS piv + WHERE class_desc = 'EXTERNAL_LIBRARY' + ) + SELECT elib.name, eprop.IsTopPackage + FROM sys.external_libraries AS elib + INNER JOIN eprop + ON eprop.major_id = elib.external_library_id AND elib.name in (%s) + AND elib.principal_id=@principalId + AND elib.language='R' AND elib.scope=? + ORDER BY elib.name ASC + ;", pkgcsv)) + + tryCatch({ + hodbc <- odbcDriverConnect(connectionString) + checkOdbcHandle(hodbc, connectionString) + + result <- sqlExecute(hodbc, query = query, + data = data, + fetch = TRUE) + }, error = function(err) { + stop(sprintf("Failed to enumerate package attributes: pkgs=(%s), error='%s'", + paste(packages, collapse = ', '), err$message), call. = FALSE) + }, finally = { + if (hodbc != -1){ + odbcClose(hodbc) + } + }) + return(result) +} diff --git a/R/R/storedProcedure.R b/R/R/storedProcedure.R new file mode 100644 index 0000000..66c90ed --- /dev/null +++ b/R/R/storedProcedure.R @@ -0,0 +1,378 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + + +#' +#'Create a Stored Procedure +#' +#'This function creates a stored procedure from a function +#'on the database and return the object. +#' +#' +#'@param connectionString character string. The connectionString to the database +#'@param name character string. The name of the stored procedure +#'@param func closure. 
The function to wrap in the stored procedure +#'@param inputParams named list. The types of the inputs, +#'where the names are the arguments and the values are the types +#'@param outputParams named list. The types of the outputs, +#'where the names are the arguments and the values are the types +#' +#'@section Warning: +#'You can add output parameters to the stored procedure +#'but you will not be able to execute the procedure from R afterwards. +#'Any stored procedure with output params must be executed directly in SQL. +#' +#'@examples +#'\dontrun{ +#' connectionString <- connectionInfo() +#' +#' ### Using a function +#' dropSproc(connectionString, "fun") +#' +#' func <- function(arg1) {return(data.frame(hello = arg1))} +#' createSprocFromFunction(connectionString, name = "fun", +#' func = func, inputParams = list(arg1="character")) +#' +#' if (checkSproc(connectionString, "fun")) { +#' print("Function 'fun' exists!") +#' executeSproc(connectionString, "fun", arg1="WORLD") +#' } +#' +#' ### Using a script +#' createSprocFromScript(connectionString, name = "funScript", +#' script = "path/to/script", inputParams = list(arg1="character")) +#' +#'} +#' +#' +#' +#'@seealso{ +#'\code{\link{dropSproc}} +#' +#'\code{\link{executeSproc}} +#' +#'\code{\link{checkSproc}} +#'} +#' +#'@return Invisibly returns the script used to create the stored procedure +#' +#'@describeIn createSprocFromFunction Create stored procedure from function +#'@export +createSprocFromFunction <- function (connectionString, name, func, inputParams = NULL, outputParams = NULL) { + + possibleTypes <- c("posixct", "numeric", "character", "integer", "logical", "raw", "dataframe") + + lapply(inputParams, function(x) { + if (!tolower(x) %in% possibleTypes) stop("Possible types are POSIXct, numeric, character, integer, logical, raw, and DataFrame.") + }) + lapply(outputParams, function(x) { + if (!tolower(x) %in% possibleTypes) stop("Possible types are POSIXct, numeric, character, integer, logical, raw, 
and DataFrame.") + }) + + inputParameters <- methods::formalArgs(func) + + if (!setequal(names(inputParams), inputParameters)){ + stop("inputParams and function arguments do not match!") + } + + procScript <- generateTSQL(func = func, spName = name, inputParams = inputParams, outputParams = outputParams) + + tryCatch({ + register(procScript, connectionString = connectionString) + }, error = function(e) { + stop(paste0("Failed during registering procedure ", name, ": ", e)) + }) + + invisible(procScript) +} + +#'@describeIn createSprocFromFunction Create stored procedure from script file, returns output of final line +#' +#'@param script character string. The path to the script to wrap in the stored procedure +#'@export +createSprocFromScript <- function (connectionString, name, script, inputParams = NULL, outputParams = NULL) { + + if (file.exists(script)){ + print(paste0("Script path exists, using file ", script)) + } else { + stop("Script path doesn't exist") + } + + text <- paste(readLines(script), collapse="\n") + + possibleTypes = c("posixct", "numeric", "character", "integer", "logical", "raw", "dataframe") + + lapply(inputParams, function(x) { + if (!tolower(x) %in% possibleTypes) stop("Possible input types are POSIXct, numeric, character, integer, logical, raw, and DataFrame.") + }) + lapply(outputParams, function(x) { + if (!tolower(x) %in% possibleTypes) stop("Possible output types are POSIXct, numeric, character, integer, logical, raw, and DataFrame.") + }) + + procScript <- generateTSQLFromScript(script = text, spName = name, inputParams = inputParams, outputParams = outputParams) + + tryCatch({ + register(procScript, connectionString = connectionString) + }, error = function(e) { + stop(paste0("Failed during registering procedure ", name, ": ", e)) + }) + + invisible(procScript) +} + + +#'Drop Stored Procedure +#' +#'@param connectionString character string. The connectionString to the database +#'@param name character string. 
The name of the stored procedure +#' +#'@examples +#'\dontrun{ +#' connectionString <- connectionInfo() +#' +#' dropSproc(connectionString, "fun") +#' +#' func <- function(arg1) {return(data.frame(hello = arg1))} +#' createSprocFromFunction(connectionString, name = "fun", +#' func = func, inputParams = list(arg1 = "character")) +#' +#' if (checkSproc(connectionString, "fun")) { +#' print("Function 'fun' exists!") +#' executeSproc(connectionString, "fun", arg1="WORLD") +#' } +#'} +#' +#' +#'@seealso{ +#' +#'\code{\link{createSprocFromFunction}} +#' +#'\code{\link{executeSproc}} +#' +#'\code{\link{checkSproc}} +#'} +#' +#'@importFrom RODBCext sqlExecute +#'@import RODBC +#' +#'@export +dropSproc <- function(connectionString, name) { + tryCatch({ + dbhandle <- odbcDriverConnect(connectionString) + output <- sqlExecute(dbhandle, "SELECT OBJECT_ID (?)", name, fetch=TRUE) + if (!is.na(output)) { + output <- sqlQuery(dbhandle, sprintf("DROP PROCEDURE %s", name)) + } else { + output <- "Named procedure doesn't exist" + } + }, error = function(e) { + stop(paste0("Error dropping the stored procedure\n")) + }, finally = { + odbcCloseAll() + }) + + if (length(output) > 0) { + print(output) + return(FALSE) + } else { + print(paste0("Successfully dropped procedure ", name)) + return(TRUE) + } +} + +#'Check if Stored Procedure is in Database +#' +#'@param connectionString character string. The connectionString to the database +#'@param name character string. 
The name of the stored procedure +#' +#'@return Whether the stored procedure exists in the database +#' +#'@examples +#'\dontrun{ +#' connectionString <- connectionInfo() +#' +#' dropSproc(connectionString, "fun") +#' +#' func <- function(arg1) {return(data.frame(hello = arg1))} +#' createSprocFromFunction(connectionString, name = "fun", +#' func = func, inputParams = list(arg1="character")) +#' if (checkSproc(connectionString, "fun")) { +#' print("Function 'fun' exists!") +#' executeSproc(connectionString, "fun", arg1="WORLD") +#' } +#'} +#' +#' +#'@seealso{ +#'\code{\link{createSprocFromFunction}} +#' +#'\code{\link{dropSproc}} +#' +#'\code{\link{executeSproc}} +#' +#'} +#' +#'@importFrom RODBCext sqlExecute +#'@import RODBC +#'@export +checkSproc <- function(connectionString, name) { + tryCatch({ + dbhandle <- odbcDriverConnect(connectionString) + output <- sqlExecute(dbhandle, "SELECT OBJECT_ID (?, N'P')", name, fetch = TRUE) + }, error = function(e) { + cat(paste0("Error executing the sqlExecute\n")) + }, finally = { + odbcCloseAll() + }) + if (is.na(output)) { + return(FALSE) + } else { + return(TRUE) + } +} + +#'Execute a Stored Procedure +#' +#'@param connectionString character string. The connectionString for the database with the stored procedure +#'@param name character string. The name of the stored procedure in the database to execute +#'@param ... named list. Parameters to pass into the procedure. These MUST be named the same as the arguments to the function. +#' +#'@section Warning: +#'Even though you can create stored procedures with output parameters, you CANNOT execute them +#'using this utility due to limitations of RODBC. 
+#' +#'@examples +#'\dontrun{ +#' connectionString <- connectionInfo() +#' +#' dropSproc(connectionString, "fun") +#' +#' func <- function(arg1) {return(data.frame(hello = arg1))} +#' createSprocFromFunction(connectionString, name = "fun", +#' func = func, inputParams = list(arg1="character")) +#' +#' if (checkSproc(connectionString, "fun")) { +#' print("Function 'fun' exists!") +#' executeSproc(connectionString, "fun", arg1="WORLD") +#' } +#'} +#'@seealso{ +#'\code{\link{createSprocFromFunction}} +#' +#'\code{\link{dropSproc}} +#' +#'\code{\link{checkSproc}} +#'} +#'@importFrom RODBCext sqlExecute +#'@import RODBC +#'@export +executeSproc <- function(connectionString, name, ...) { + if (class(name) != "character") + stop("the argument must be the name of a Sproc") + + res <- createQuery(connectionString = connectionString, name = name, ...) + query <- res$query + paramOrder <- res$inputParams + df = data.frame(...) + + if (nrow(df) != 0 && ncol(df) != 0) { + df <- df[paramOrder] + } + + tryCatch({ + dbhandle <- odbcDriverConnect(connectionString) + result <- sqlExecute(dbhandle, query, df, fetch = TRUE) + }, error = function(e) { + stop(paste0("Error in SQL Execution: ", e, "\n")) + }, finally ={ + odbcCloseAll() + }) + + if (is.list(result)) { + return(result) + } else if (!is.character(result)) { + stop(paste("Error executing the stored procedure:", name)) + } else { + return(NULL) + } +} + +# +# Get the parameters of the stored procedure to create the query +# +#@param connectionString character string. The connectionString to the database +#@param name character string. The name of the stored procedure +# +#@return the parameters +# +getSprocParams <- function(connectionString, name) { + query <- "SELECT 'Parameter_name' = name, 'Type' = type_name(user_type_id), + 'Output' = is_output FROM sys.parameters WHERE OBJECT_ID = ?" 
+ + inputDataName <- NULL + tryCatch({ + dbhandle <- odbcDriverConnect(connectionString) + + number <- sqlExecute(dbhandle, "SELECT OBJECT_ID (?)", name, fetch=TRUE)[[1]] + + params <- sqlExecute(dbhandle, query, number, fetch=TRUE) + outputParams <- split(params,params$Output)[['1']] + inputParams <- split(params,params$Output)[['0']] + + text <- paste0(collapse="", lapply(sqlExecute(dbhandle, "EXEC sp_helptext ?", name, fetch = TRUE), as.character)) + matched <- regmatches(text, gregexpr("input_data_1_name = [^,]+",text))[[1]] + if (length(matched) == 1) { + inputDataName <- regmatches(matched, gregexpr("N'.*'",matched))[[1]] + inputDataName <- gsub("(N'|')","", inputDataName) + } + }, error = function(e) { + cat(paste0("Error executing the sqlExecute\n")) + odbcCloseAll() + stop(e) + }, finally ={ + odbcCloseAll() + }) + list(inputParams = inputParams, inputDataName = inputDataName, outputParams = outputParams) +} + +#Create the necessary query to execute the stored procedure +# +#@param connectionString character string. The connectionString to the database +#@param name character string. The name of the stored procedure +#@param ... The arguments for the stored procedure +# +#@return the query +# +createQuery <- function(connectionString, name, ...) 
{ + #Get and process params from the stored procedure in the database + storedProcParams <- getSprocParams(connectionString = connectionString, name = name) + params <- storedProcParams$inputParams + inList <- c() + + if (!is.null(params)) { + for(i in seq_len(nrow(params))) { + parameter_outer <- params[i,]$Parameter_name + parameter <- gsub('.{6}$', '', parameter_outer) + parameter <- gsub('@','', parameter) + type <- params[i,]$Type + + inList <- c(inList,parameter) + } + } + inLabels <- NULL + if (!(length(list(...)) == 1 && is.null(list(...)[[1]]))) { + inLabels <- labels(list(...)) + if (!all(inLabels %in% inList)) { + stop("You must provide named arguments that match the parameters in the stored procedure.") + } + } + #add necessary variable declarations and value assignments + + query <- paste0("exec ", name) + for(p in inList) { + paramName <- p + query <- paste0(query, " @", paramName, "_outer = ?,") + } + query <- gsub(",$", "", query) + list(query=query, inputParams=inList) +} diff --git a/R/R/storedProcedureScripting.R b/R/R/storedProcedureScripting.R new file mode 100644 index 0000000..e259c7f --- /dev/null +++ b/R/R/storedProcedureScripting.R @@ -0,0 +1,220 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
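+
+# For orientation, the T-SQL generated by the helpers in this file has roughly
+# the following shape (hypothetical example, not from the source: a function
+# registered as "fun" with inputParams = list(arg1 = "character") and no
+# output parameters):
+#
+#   CREATE PROCEDURE fun
+#    @arg1_outer nvarchar(max)
+#   AS
+#   BEGIN TRY
+#   exec sp_execute_external_script
+#   @language = N'R',
+#   @script = N'<escaped R code that defines and calls the function>',
+#   @params = N'@arg1 nvarchar(max)',
+#   @arg1 = @arg1_outer
+#   END TRY
+#   BEGIN CATCH
+#   THROW;
+#   END CATCH;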
+
+
+# the list with type conversion info
+sqlTypes <- list(posixct = "datetime", numeric = "float",
+                 character = "nvarchar(max)", integer = "int",
+                 logical = "bit", raw = "varbinary(max)", dataframe = "nvarchar(max)")
+
+getSqlType <- function(rType) {
+    sqlTypes[[tolower(rType)]]
+}
+
+# creates the top part of the sql script (up to R code)
+getHeader <- function(spName, inputParams, outputParams) {
+    header <- c(paste0("CREATE PROCEDURE ", spName),
+                handleHeadParams(inputParams, outputParams),
+                "AS",
+                "BEGIN TRY",
+                "exec sp_execute_external_script",
+                "@language = N'R',", "@script = N'")
+    return(paste0(header, collapse = "\n"))
+}
+
+handleHeadParams <- function(inputParams, outputParams)
+{
+    paramString <- c()
+
+    makeString <- function(name, d, output = "") {
+        rType <- d[[name]]
+        sqlType <- getSqlType(rType)
+        paste0(" @", name, "_outer ", sqlType, output)
+    }
+
+    for(name in names(inputParams)) {
+        paramString <- c(paramString, makeString(name, inputParams))
+    }
+    for(name in names(outputParams)) {
+        rType <- outputParams[[name]]
+        if (tolower(rType) != "dataframe") {
+            paramString <- c(paramString, makeString(name, outputParams, " output"))
+        }
+    }
+    return(paste0(paramString, collapse = ",\n"))
+}
+
+generateTSQL <- function(func, spName, inputParams = NULL, outputParams = NULL) {
+    # header to create the stored procedure
+    header <- getHeader(spName, inputParams, outputParams)
+
+    # vector containing R code
+    rCode <- getRCode(func, outputParams)
+
+    # tail of the sp
+    tail <- getTail(inputParams, outputParams)
+
+    # note: paste0() has no sep argument, so join the pieces with paste()
+    paste(header, rCode, tail, sep = "\n")
+}
+
+generateTSQLFromScript <- function(script, spName, inputParams, outputParams) {
+    # header to create the stored procedure
+    header <- getHeader(spName, inputParams = inputParams, outputParams = outputParams)
+
+    # vector containing R code
+    rCode <- getRCodeFromScript(script = script, outputParams = outputParams)
+
+    # tail of the sp
+    tail <- getTail(inputParams =
inputParams, outputParams = outputParams)
+
+    # note: paste0() has no sep argument, so join the pieces with paste()
+    paste(header, rCode, tail, sep = "\n")
+}
+
+
+# creates the bottom part of the sql script (after R code)
+getTail <- function(inputParams, outputParams) {
+    tail <- c("'")
+    tailParams <- handleTailParams(inputParams, outputParams)
+    if (tailParams != "")
+        tail <- c("',")
+    tail <- c(tail,
+              tailParams,
+              "END TRY",
+              "BEGIN CATCH",
+              "THROW;",
+              "END CATCH;")
+    return(paste0(tail, collapse = "\n"))
+}
+
+handleTailParams <- function(inputParams, outputParams) {
+    inDataString <- c()
+    outDataString <- c()
+    paramString <- c()
+    overallParams <- c()
+
+    makeString <- function(name, d, output = "") {
+        rType <- d[[name]]
+        if (tolower(rType) == "dataframe") {
+            if (output == "") {
+                c(paste0("@input_data_1 = @", name, "_outer"),
+                  paste0("@input_data_1_name = N'", name, "'"))
+            } else {
+                c(paste0("@output_data_1_name = N'", name, "'"))
+            }
+        } else {
+            sqlType <- getSqlType(rType)
+            overallParams <<- c(overallParams, paste0("@", name, " ", sqlType, output))
+            paste0("@", name, " = ", "@", name, "_outer", output)
+        }
+    }
+
+    for(name in names(inputParams)) {
+        rType <- inputParams[[name]]
+        if (tolower(rType) == "dataframe") {
+            inDataString <- c(makeString(name, inputParams))
+        } else {
+            paramString <- c(paramString, makeString(name, inputParams))
+        }
+    }
+    for(name in names(outputParams)) {
+        rType <- outputParams[[name]]
+        if (tolower(rType) == "dataframe") {
+            outDataString <- c(makeString(name, outputParams, " output"))
+        } else {
+            paramString <- c(paramString, makeString(name, outputParams, " output"))
+        }
+    }
+    if (length(overallParams) > 0) {
+        overallParams <- paste0(overallParams, collapse = ", ")
+        overallParams <- paste0("@params = N'", overallParams, "'")
+    }
+    return(paste0(c(inDataString, outDataString, overallParams, paramString), collapse = ",\n"))
+}
+
+getRCodeFromScript <- function(script, inputParams, outputParams) {
+    # escape single quotes and get rid of tabs
+    script <- sapply(script, gsub,
pattern = "\t", replacement = " ")
+    script <- sapply(script, gsub, pattern = "'", replacement = "''")
+
+    return(paste0(script, collapse = "\n"))
+}
+
+getRCode <- function(func, outputParams) {
+    name <- as.character(substitute(func))
+
+    funcBody <- deparse(func)
+
+    # add on the function definition
+    funcBody[1] <- paste(name, "<-", funcBody[1], sep = " ")
+
+    # escape single quotes and get rid of tabs
+    funcBody <- sapply(funcBody, gsub, pattern = "\t", replacement = " ")
+    funcBody <- sapply(funcBody, gsub, pattern = "'", replacement = "''")
+
+    inputParameters <- methods::formalArgs(func)
+
+    funcInputNames <- paste(inputParameters, inputParameters,
+                            sep = " = ")
+    funcInputNames <- paste(funcInputNames, collapse = ", ")
+
+    # add function call
+    funcBody <- c(funcBody, paste0("result <- ", name,
+                                   paste0("(", funcInputNames, ")")))
+
+    # add appropriate ending
+    ending <- getEnding(outputParams)
+    funcBody <- c(funcBody, ending)
+    return(paste0(funcBody, collapse = "\n"))
+}
+
+#
+# Get ending string
+# We change the result into an OutputDataSet - we only expect a single OutputDataSet result
+getEnding <- function(outputParams) {
+    outputDataSetName <- "OutputDataSet"
+    for(name in names(outputParams)) {
+        if (tolower(outputParams[[name]]) == "dataframe") {
+            outputDataSetName <- name
+        }
+    }
+    ending <- c("if (is.data.frame(result)) {",
+                paste0(" ", outputDataSetName, " <- result")
+    )
+
+    if (length(outputParams) > 0) {
+        ending <- c(ending, "} else if (is.list(result)) {")
+
+        # data frame and scalar outputs are both pulled out of the result list by name
+        for(name in names(outputParams)) {
+            ending <- c(ending, paste0(" ", name, " <- result$", name))
+        }
+        ending <- c(ending,
+                    "} else if (!is.null(result)) {",
+                    "    stop(\"the R function must return a list\")"
+        )
+    }
+    ending <- c(ending, "}")
+    return(ending)
+}
+
+# @import RODBC
+# Execute the registration script
+register <- function(registrationScript,
connectionString) {
+    output <- character(0)
+
+    tryCatch({
+        dbhandle <- odbcDriverConnect(connectionString)
+        output <- sqlQuery(dbhandle, registrationScript)
+    }, error = function(e) {
+        stop(paste0("Error in SQL Execution:\n", e))
+    }, finally = {
+        odbcCloseAll()
+    })
+    if (length(output) > 0) {
+        stop(output)
+    }
+}
+
diff --git a/R/README.md b/R/README.md
new file mode 100644
index 0000000..54c3f71
--- /dev/null
+++ b/R/README.md
@@ -0,0 +1,151 @@
+# sqlmlutils
+
+sqlmlutils is an R package to help execute R code on a SQL Server machine.
+
+# Installation
+
+Run
+```
+R CMD INSTALL dist/sqlmlutils_0.5.0.zip
+```
+OR, to build a new package file and install, run
+```
+.\buildandinstall.cmd
+```
+
+# Getting started
+
+Shown below are the important functions sqlmlutils provides:
+```R
+connectionInfo           # Create a connection string for connecting to the SQL Server
+
+executeFunctionInSQL     # Execute an R function inside the SQL database
+executeScriptInSQL       # Execute an R script inside the SQL database
+executeSQLQuery          # Execute a SQL query on the database and return the resultant table
+
+createSprocFromFunction  # Create a stored procedure based on an R function inside the SQL database
+createSprocFromScript    # Create a stored procedure based on an R script inside the SQL database
+checkSproc               # Check whether a stored procedure exists in the SQL database
+dropSproc                # Drop a stored procedure in the SQL database
+executeSproc             # Execute a stored procedure in the SQL database
+
+sql_install.packages     # Install packages in the SQL database
+sql_remove.packages      # Remove packages from the SQL database
+sql_installed.packages   # Enumerate packages that are installed on the SQL database
+```
+
+# Examples
+
+### Execute In SQL
+##### Execute an R function in-database using sp_execute_external_script
+
+```R
+library(sqlmlutils)
+connection <- connectionInfo()
+
+funcWithArgs <- function(arg1, arg2){
+    return(c(arg1, arg2))
+}
+result <- executeFunctionInSQL(connection,
funcWithArgs, arg1="result1", arg2="result2") +``` + +##### Generate a linear model without the data leaving the machine + +```R +library(sqlmlutils) +connection <- connectionInfo(database="AirlineTestDB") + +linearModel <- function(in_df, xCol, yCol) { + lm(paste0(yCol, " ~ ", xCol), in_df) +} + +model <- executeFunctionInSQL(connectionString = connection, func = linearModel, xCol = "CRSDepTime", yCol = "ArrDelay", + inputDataQuery = "SELECT TOP 100 * FROM airline5000") +model +``` + +##### Execute a SQL Query from R + +```R +library(sqlmlutils) +connection <- connectionInfo(database="AirlineTestDB") + +dataTable <- executeSQLQuery(connectionString = connection, sqlQuery="SELECT TOP 100 * FROM airline5000") +stopifnot(nrow(dataTable) == 100) +stopifnot(ncol(dataTable) == 30) +``` + +### Stored Procedures (Sproc) +##### Create and call a T-SQL stored procedure based on a R function + +```R +library(sqlmlutils) + +spPredict <- function(inputDataFrame) { + library(RevoScaleR) + model <- rxLinMod(ArrDelay ~ CRSDepTime, inputDataFrame) + rxPredict(model, inputDataFrame) +} + +connection <- connectionInfo(database="AirlineTestDB") +inputParams <- list(inputDataFrame = "Dataframe") + +name = "prediction" + +createSprocFromFunction(connectionString = connection, name = name, func = spPredict, inputParams = inputParams) +stopifnot(checkSproc(connectionString = connection, name = name)) + +predictions <- executeSproc(connectionString = connection, name = name, inputDataFrame = "select ArrDelay, CRSDepTime, DayOfWeek from airline5000") +stopifnot(nrow(predictions) == 5000) + +dropSproc(connectionString = connection, name = name) +``` + +### Package Management +##### Install and remove packages from SQL Server + +```R +library(sqlmlutils) +connection <- connectionInfo(database="AirlineTestDB") + +# install glue on sql server +pkgs <- c("glue") +sql_install.packages(connectionString = connection, pkgs, verbose = TRUE, scope="PUBLIC") + +# confirm glue is installed on sql 
server +r<-sql_installed.packages(connectionString = connection, fields=c("Package", "LibPath", "Attributes", "Scope")) +View(r) + +# use glue on sql server +useLibraryGlueInSql <- function() +{ + library(glue) + + name <- "Fred" + age <- 50 + anniversary <- as.Date("1991-10-12") + glue('My name is {name},', + 'my age next year is {age + 1},', + 'my anniversary is {format(anniversary, "%A, %B %d, %Y")}.') +} + +result <- executeFunctionInSQL(connectionString = connection, func = useLibraryGlueInSql) +print(result) + +# remove glue from sql server +sql_remove.packages(connectionString = connection, pkgs, scope="PUBLIC") +``` + +# Notes for Developers + +### Running the tests + +1. Make sure a SQL Server with an updated ML Services R is running on localhost. +2. Restore the AirlineTestDB from the .bak file in this repo +3. Make sure Trusted (Windows) authentication works for connecting to the database + +### Notable TODOs and open issues + +1. Output Parameter execution does not work - RODBCext limitations? +2. Testing from a Linux client has not been performed. diff --git a/R/buildandinstall.cmd b/R/buildandinstall.cmd new file mode 100644 index 0000000..4cf8fac --- /dev/null +++ b/R/buildandinstall.cmd @@ -0,0 +1,5 @@ +pushd . +cd .. +R CMD INSTALL --build R +mv sqlmlutils_0.5.0.zip R/dist +popd diff --git a/R/dist/sqlmlutils_0.5.0.zip b/R/dist/sqlmlutils_0.5.0.zip new file mode 100644 index 0000000..2d9e9b4 Binary files /dev/null and b/R/dist/sqlmlutils_0.5.0.zip differ diff --git a/R/man/checkSproc.Rd b/R/man/checkSproc.Rd new file mode 100644 index 0000000..2b843f7 --- /dev/null +++ b/R/man/checkSproc.Rd @@ -0,0 +1,46 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/storedProcedure.R +\name{checkSproc} +\alias{checkSproc} +\title{Check if Stored Procedure is in Database} +\usage{ +checkSproc(connectionString, name) +} +\arguments{ +\item{connectionString}{character string. 
The connectionString to the database}
+
+\item{name}{character string. The name of the stored procedure}
+}
+\value{
+Whether the stored procedure exists in the database
+}
+\description{
+Check if Stored Procedure is in Database
+}
+\examples{
+\dontrun{
+connectionString <- connectionInfo()
+
+dropSproc(connectionString, "fun")
+
+func <- function(arg1) {return(data.frame(hello = arg1))}
+createSprocFromFunction(connectionString, name = "fun",
+                        func = func, inputParams = list(arg1="character"))
+if (checkSproc(connectionString, "fun")) {
+    print("Function 'fun' exists!")
+    executeSproc(connectionString, "fun", arg1="WORLD")
+}
+}
+
+
+}
+\seealso{
+{
+\code{\link{createSprocFromFunction}}
+
+\code{\link{dropSproc}}
+
+\code{\link{executeSproc}}
+
+}
+}
diff --git a/R/man/connectionInfo.Rd b/R/man/connectionInfo.Rd
new file mode 100644
index 0000000..9654e83
--- /dev/null
+++ b/R/man/connectionInfo.Rd
@@ -0,0 +1,38 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/executeInSQL.R
+\name{connectionInfo}
+\alias{connectionInfo}
+\title{Create a Connection String for a SQL Server Database}
+\usage{
+connectionInfo(driver = "SQL Server", server = "localhost",
+  database = "master", uid = NULL, pwd = NULL)
+}
+\arguments{
+\item{driver}{The driver to use for the connection - defaults to SQL Server}
+
+\item{server}{The server to connect to - defaults to localhost}
+
+\item{database}{The database to connect to - defaults to master}
+
+\item{uid}{The user id for the connection. If uid is NULL, defaults to Trusted Connection}
+
+\item{pwd}{The password for the connection.
If uid is not NULL, pwd is required}
+}
+\value{
+A fully formed connection string
+}
+\description{
+Create a connection string for connecting to a SQL Server database
+}
+\examples{
+\dontrun{
+
+connectionInfo()
+[1] "Driver={SQL Server};Server=localhost;Database=master;Trusted_Connection=Yes;"
+
+connectionInfo(server="ServerName", database="AirlineTestDB", uid="username", pwd="pass")
+[1] "Driver={SQL Server};Server=ServerName;Database=AirlineTestDB;uid=username;pwd=pass;"
+}
+
+
+}
diff --git a/R/man/createSprocFromFunction.Rd b/R/man/createSprocFromFunction.Rd
new file mode 100644
index 0000000..fb9c299
--- /dev/null
+++ b/R/man/createSprocFromFunction.Rd
@@ -0,0 +1,83 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/storedProcedure.R
+\name{createSprocFromFunction}
+\alias{createSprocFromFunction}
+\alias{createSprocFromScript}
+\title{Create a Stored Procedure}
+\usage{
+createSprocFromFunction(connectionString, name, func, inputParams = NULL,
+  outputParams = NULL)
+
+createSprocFromScript(connectionString, name, script, inputParams = NULL,
+  outputParams = NULL)
+}
+\arguments{
+\item{connectionString}{character string. The connectionString to the database}
+
+\item{name}{character string. The name of the stored procedure}
+
+\item{func}{closure. The function to wrap in the stored procedure}
+
+\item{inputParams}{named list. The types of the inputs,
+where the names are the arguments and the values are the types}
+
+\item{outputParams}{named list. The types of the outputs,
+where the names are the arguments and the values are the types}
+
+\item{script}{character string. The path to the script to wrap in the stored procedure}
+}
+\value{
+Invisibly returns the script used to create the stored procedure
+}
+\description{
+This function creates a stored procedure from a function
+on the database and invisibly returns the creation script.
+} +\section{Functions}{ +\itemize{ +\item \code{createSprocFromFunction}: Create stored procedure from function + +\item \code{createSprocFromScript}: Create stored procedure from script file, returns output of final line +}} + +\section{Warning}{ + +You can add output parameters to the stored procedure +but you will not be able to execute the procedure from R afterwards. +Any stored procedure with output params must be executed directly in SQL. +} + +\examples{ +\dontrun{ +connectionString <- connectionInfo() + +### Using a function +dropSproc(connectionString, "fun") + +func <- function(arg1) {return(data.frame(hello = arg1))} +createSprocFromFunction(connectionString, name = "fun", + func = func, inputParams = list(arg1="character")) + +if (checkSproc(connectionString, "fun")) { + print("Function 'fun' exists!") + executeSproc(connectionString, "fun", arg1="WORLD") +} + +### Using a script +createSprocFromScript(connectionString, name = "funScript", + script = "path/to/script", inputParams = list(arg1="character")) + +} + + + +} +\seealso{ +{ +\code{\link{dropSproc}} + +\code{\link{executeSproc}} + +\code{\link{checkSproc}} +} +} diff --git a/R/man/dropSproc.Rd b/R/man/dropSproc.Rd new file mode 100644 index 0000000..7485282 --- /dev/null +++ b/R/man/dropSproc.Rd @@ -0,0 +1,44 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/storedProcedure.R +\name{dropSproc} +\alias{dropSproc} +\title{Drop Stored Procedure} +\usage{ +dropSproc(connectionString, name) +} +\arguments{ +\item{connectionString}{character string. The connectionString to the database} + +\item{name}{character string. 
The name of the stored procedure} +} +\description{ +Drop Stored Procedure +} +\examples{ +\dontrun{ +connectionString <- connectionInfo() + +dropSproc(connectionString, "fun") + +func <- function(arg1) {return(data.frame(hello = arg1))} +createSprocFromFunction(connectionString, name = "fun", + func = func, inputParams = list(arg1 = "character")) + +if (checkSproc(connectionString, "fun")) { + print("Function 'fun' exists!") + executeSproc(connectionString, "fun", arg1="WORLD") +} +} + + +} +\seealso{ +{ + +\code{\link{createSprocFromFunction}} + +\code{\link{executeSproc}} + +\code{\link{checkSproc}} +} +} diff --git a/R/man/executeFunctionInSQL.Rd b/R/man/executeFunctionInSQL.Rd new file mode 100644 index 0000000..ff2b914 --- /dev/null +++ b/R/man/executeFunctionInSQL.Rd @@ -0,0 +1,40 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/executeInSQL.R +\name{executeFunctionInSQL} +\alias{executeFunctionInSQL} +\title{Execute a function in SQL} +\usage{ +executeFunctionInSQL(connectionString, func, ..., inputDataQuery = "") +} +\arguments{ +\item{connectionString}{character string. The connectionString to the database} + +\item{func}{closure. The function to execute} + +\item{...}{A named list of arguments to pass into the function} + +\item{inputDataQuery}{character string. A string to query the database. 
+The result of the query will be converted to a data frame and passed as the first argument to the function}
+}
+\value{
+The returned value from the function
+}
+\description{
+Execute a function in SQL
+}
+\examples{
+\dontrun{
+connection <- connectionInfo(database = "AirlineTestDB")
+
+foo <- function(in_df, arg) {
+    list(data = in_df, value = arg)
+}
+executeFunctionInSQL(connection, foo, arg = 12345,
+                     inputDataQuery = "SELECT top 1 * from airline5000")
+}
+
+
+}
+\seealso{
+\code{\link{executeScriptInSQL}} to execute a script file instead of a function in SQL
+}
diff --git a/R/man/executeSQLQuery.Rd b/R/man/executeSQLQuery.Rd
new file mode 100644
index 0000000..eb274be
--- /dev/null
+++ b/R/man/executeSQLQuery.Rd
@@ -0,0 +1,27 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/executeInSQL.R
+\name{executeSQLQuery}
+\alias{executeSQLQuery}
+\title{Execute a query in SQL}
+\usage{
+executeSQLQuery(connectionString, sqlQuery)
+}
+\arguments{
+\item{connectionString}{character string. The connectionString to the database}
+
+\item{sqlQuery}{character string. The query to execute}
+}
+\value{
+The data frame returned by the query to the database
+}
+\description{
+Execute a query in SQL
+}
+\examples{
+\dontrun{
+connection <- connectionInfo(database="AirlineTestDB")
+executeSQLQuery(connection, sqlQuery="SELECT top 1 * from airline5000")
+}
+
+
+}
diff --git a/R/man/executeScriptInSQL.Rd b/R/man/executeScriptInSQL.Rd
new file mode 100644
index 0000000..e495fd9
--- /dev/null
+++ b/R/man/executeScriptInSQL.Rd
@@ -0,0 +1,25 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/executeInSQL.R
+\name{executeScriptInSQL}
+\alias{executeScriptInSQL}
+\title{Execute a script in SQL}
+\usage{
+executeScriptInSQL(connectionString, script, inputDataQuery = "")
+}
+\arguments{
+\item{connectionString}{character string. The connectionString to the database}
+
+\item{script}{character string.
The path to the script to execute in SQL} + +\item{inputDataQuery}{character string. A string to query the database. +The result of the query will be put into a data frame into the variable "InputDataSet" in the environment} +} +\value{ +The returned value from the last line of the script +} +\description{ +Execute a script in SQL +} +\seealso{ +\code{\link{executeFunctionInSQL}} to execute a user function instead of a script in SQL +} diff --git a/R/man/executeSproc.Rd b/R/man/executeSproc.Rd new file mode 100644 index 0000000..df3c9d6 --- /dev/null +++ b/R/man/executeSproc.Rd @@ -0,0 +1,49 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/storedProcedure.R +\name{executeSproc} +\alias{executeSproc} +\title{Execute a Stored Procedure} +\usage{ +executeSproc(connectionString, name, ...) +} +\arguments{ +\item{connectionString}{character string. The connectionString for the database with the stored procedure} + +\item{name}{character string. The name of the stored procedure in the database to execute} + +\item{...}{named list. Parameters to pass into the procedure. These MUST be named the same as the arguments to the function.} +} +\description{ +Execute a Stored Procedure +} +\section{Warning}{ + +Even though you can create stored procedures with output parameters, you CANNOT execute them +using this utility due to limitations of RODBC. 
+} + +\examples{ +\dontrun{ +connectionString <- connectionInfo() + +dropSproc(connectionString, "fun") + +func <- function(arg1) {return(data.frame(hello = arg1))} +createSprocFromFunction(connectionString, name = "fun", + func = func, inputParams = list(arg1="character")) + +if (checkSproc(connectionString, "fun")) { + print("Function 'fun' exists!") + executeSproc(connectionString, "fun", arg1="WORLD") +} +} +} +\seealso{ +{ +\code{\link{createSprocFromFunction}} + +\code{\link{dropSproc}} + +\code{\link{checkSproc}} +} +} diff --git a/R/man/sql_install.packages.Rd b/R/man/sql_install.packages.Rd new file mode 100644 index 0000000..050abd2 --- /dev/null +++ b/R/man/sql_install.packages.Rd @@ -0,0 +1,40 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/sqlPackage.R +\name{sql_install.packages} +\alias{sql_install.packages} +\title{sql_install.packages} +\usage{ +sql_install.packages(connectionString, pkgs, skipMissing = FALSE, + repos = getOption("repos"), verbose = getOption("verbose"), + scope = "private", owner = "") +} +\arguments{ +\item{connectionString}{ODBC connection string to Microsoft SQL Server database.} + +\item{pkgs}{character vector of the names of packages whose current versions should be downloaded from the repositories. If repos = NULL, a character vector of file paths of .zip files containing binary builds of packages. (http:// and file:// URLs are also accepted and the files will be downloaded and installed from local copies).} + +\item{skipMissing}{logical. If TRUE, skips missing dependent packages for which otherwise an error is generated.} + +\item{repos}{character vector, the base URL(s) of the repositories to use.Can be NULL to install from local files, directories.} + +\item{verbose}{logical. If TRUE, more detailed information is given during installation of packages.} + +\item{scope}{character string. Should be either "public" or "private". 
"public" installs the packages on per database public location on SQL server which in turn can be used (referred) by multiple different users. "private" installs the packages on per database, per user private location on SQL server which is only accessible to the single user.} + +\item{owner}{character string. Should be either empty '' or a valid SQL database user account name. Only 'dbo' or users in 'db_owner' role for a database can specify this value to install packages on behalf of other users. A user who is member of the 'db_owner' group can set owner='dbo' to install on the "public" folder.} +} +\value{ +invisible(NULL) +} +\description{ +Installs R packages on a SQL Server database. Packages are downloaded on the client and then copied and installed to SQL Server into "public" and "private" folders. Packages in the "public" folders can be loaded by all database users running R script in SQL. Packages in the "private" folder can be loaded only by a single user. 'dbo' users always install into the "public" folder. Users who are members of the 'db_owner' role can install to both "public" and "private" folders. All other users can only install packages to their "private" folder. 
+} +\seealso{ +{ +\code{\link{sql_remove.packages}} to remove packages + +\code{\link{sql_installed.packages}} to enumerate the installed packages + +\code{\link{install.packages}} for the base version of this function +} +} diff --git a/R/man/sql_installed.packages.Rd b/R/man/sql_installed.packages.Rd new file mode 100644 index 0000000..34f434d --- /dev/null +++ b/R/man/sql_installed.packages.Rd @@ -0,0 +1,42 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/sqlPackage.R +\name{sql_installed.packages} +\alias{sql_installed.packages} +\title{sql_installed.packages} +\usage{ +sql_installed.packages(connectionString, priority = NULL, noCache = FALSE, + fields = "Package", subarch = NULL, scope = "private", owner = "") +} +\arguments{ +\item{connectionString}{ODBC connection string to Microsoft SQL Server database.} + +\item{priority}{character vector or NULL (default). If non-null, used to select packages; "high" is equivalent to c("base", "recommended"). To select all packages without an assigned priority use priority = "NA".} + +\item{noCache}{logical. If TRUE, do not use cached information, nor cache it.} + +\item{fields}{a character vector giving the fields to extract from each package's DESCRIPTION file, or NULL. If NULL, the following fields are used: +"Package", "LibPath", "Version", "Priority", "Depends", "Imports", "LinkingTo", "Suggests", "Enhances", "License", "License_is_FOSS", "License_restricts_use", "OS_type", "MD5sum", "NeedsCompilation", and "Built". +Unavailable fields result in NA values.} + +\item{subarch}{character string or NULL. 
If non-null and non-empty, used to select packages which are installed for that sub-architecture.}
+
+\item{scope}{character string which can be "private" or "public".}
+
+\item{owner}{character string of a user whose private packages shall be listed (available to dbo or db_owner users only)}
+}
+\value{
+matrix with enumerated packages
+}
+\description{
+Enumerates the currently installed R packages on a SQL Server for the current database
+}
+\seealso{
+{
+\code{\link{sql_install.packages}} to install packages
+
+\code{\link{sql_remove.packages}} to remove packages
+
+\code{\link{installed.packages}} for the base version of this function
+
+}
+}
diff --git a/R/man/sql_remove.packages.Rd b/R/man/sql_remove.packages.Rd
new file mode 100644
index 0000000..babc42d
--- /dev/null
+++ b/R/man/sql_remove.packages.Rd
@@ -0,0 +1,41 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/sqlPackage.R
+\name{sql_remove.packages}
+\alias{sql_remove.packages}
+\title{sql_remove.packages}
+\usage{
+sql_remove.packages(connectionString, pkgs, dependencies = TRUE,
+  checkReferences = TRUE, verbose = getOption("verbose"),
+  scope = "private", owner = "")
+}
+\arguments{
+\item{connectionString}{ODBC connection string to SQL Server database.}
+
+\item{pkgs}{character vector of names of the packages to be removed.}
+
+\item{dependencies}{logical. If TRUE, does dependency resolution of the packages being removed and removes the dependent packages also if the dependent packages aren't referenced by other packages outside the dependency closure.}
+
+\item{checkReferences}{logical. If TRUE, verifies there are no references to the dependent packages by other packages outside the dependency closure. Use FALSE to force removal of packages even when other packages depend on them.}
+
+\item{verbose}{logical. If TRUE, more detailed information is given during removal of packages.}
+
+\item{scope}{character string. Should be either "public" or "private".
"public" removes the packages from a per-database public location on SQL Server which in turn could have been used (referred) by multiple different users. "private" removes the packages from a per-database, per-user private location on SQL Server which is only accessible to the single user.} + +\item{owner}{character string. Should be either empty '' or a valid SQL database user account name. Only 'dbo' or users in 'db_owner' role for a database can specify this value to remove packages on behalf of other users. A user who is member of the 'db_owner' group can set owner='dbo' to remove packages from the "public" folder.} +} +\value{ +invisible(NULL) +} +\description{ +sql_remove.packages +} +\seealso{ +{ +\code{\link{sql_install.packages}} to install packages + +\code{\link{sql_installed.packages}} to enumerate the installed packages + +\code{\link{remove.packages}} for the base version of this function + +} +} diff --git a/R/sqlmlutils.Rproj b/R/sqlmlutils.Rproj new file mode 100644 index 0000000..f54a4b6 --- /dev/null +++ b/R/sqlmlutils.Rproj @@ -0,0 +1,21 @@ +Version: 1.0 + +RestoreWorkspace: No +SaveWorkspace: No +AlwaysSaveHistory: Default + +EnableCodeIndexing: Yes +UseSpacesForTab: Yes +NumSpacesForTab: 4 +Encoding: UTF-8 + +RnwWeave: Sweave +LaTeX: pdfLaTeX + +AutoAppendNewline: Yes +StripTrailingWhitespace: Yes + +BuildType: Package +PackageUseDevtools: Yes +PackageInstallArgs: --no-multiarch --with-keep.source +PackageRoxygenize: rd,collate,namespace diff --git a/R/tests/testthat.R b/R/tests/testthat.R new file mode 100644 index 0000000..e42933a --- /dev/null +++ b/R/tests/testthat.R @@ -0,0 +1,7 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+ +library(testthat) +library(sqlmlutils) + +test_check("sqlmlutils") diff --git a/R/tests/testthat/helper-Setup.R b/R/tests/testthat/helper-Setup.R new file mode 100644 index 0000000..dbcf412 --- /dev/null +++ b/R/tests/testthat/helper-Setup.R @@ -0,0 +1,27 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +library(sqlmlutils) +library(methods) +library(testthat) + +options(keep.source = TRUE) +Sys.setenv(TZ='GMT') +Server <- '' +if (Server == '') Server <- "." +cnnstr <- connectionInfo(server=Server, database="AirlineTestDB") + +testthatDir <- getwd() +R_Root <- file.path(testthatDir, "../..") +scriptDirectory <- file.path(testthatDir, "scripts") + +TestArgs <- list( + # Compute context specifications + gitRoot = R_Root, + testDirectory = testthatDir, + scriptDirectory = scriptDirectory, + connectionString = cnnstr +) + +options(TestArgs = TestArgs) +rm(TestArgs) diff --git a/R/tests/testthat/scripts/script.txt b/R/tests/testthat/scripts/script.txt new file mode 100644 index 0000000..e6bf20b --- /dev/null +++ b/R/tests/testthat/scripts/script.txt @@ -0,0 +1,7 @@ +foo <- function(t1, t2, t3) { + print(t1) + warning(t2) + return(t3) +} + +foo("Hello","WARNING", InputDataSet) diff --git a/R/tests/testthat/scripts/script2.txt b/R/tests/testthat/scripts/script2.txt new file mode 100644 index 0000000..500149a --- /dev/null +++ b/R/tests/testthat/scripts/script2.txt @@ -0,0 +1,5 @@ + +sum1 <- 1+2 +sum2 <- 5+6 +product <- sum1 * sum2 +product diff --git a/R/tests/testthat/scripts/script3.R b/R/tests/testthat/scripts/script3.R new file mode 100644 index 0000000..02fdba3 --- /dev/null +++ b/R/tests/testthat/scripts/script3.R @@ -0,0 +1,2 @@ +product <- num1 * num2 +out_df <- rbind(in_df, product) diff --git a/R/tests/testthat/test.executeInSqlTests.R b/R/tests/testthat/test.executeInSqlTests.R new file mode 100644 index 0000000..cad71fa --- /dev/null +++ b/R/tests/testthat/test.executeInSqlTests.R @@ -0,0 +1,203 @@ +# 
Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +library(testthat) +context("executeInSQL tests") + +TestArgs <- options("TestArgs")$TestArgs +connection <- TestArgs$connectionString +scriptDir <- TestArgs$scriptDirectory + + + +test_that("Test with named args", { + funcWithArgs <- function(arg1, arg2){ + print(arg1) + return(arg2) + } + expect_output( + expect_equal( + executeFunctionInSQL(connection, funcWithArgs, arg1="blah1", arg2="blah2"), + "blah2"), + "blah1" + ) +}) + +test_that("Test ordered arguments", { + funcNum <- function(arg1, arg2){ + stopifnot(typeof(arg1) == "integer") + stopifnot(typeof(arg2) == "double") + return(arg1 / arg2) + } + expect_error(executeFunctionInSQL(connection, funcNum, 2)) + expect_equal(executeFunctionInSQL(connection, funcNum, as.integer(2), 3), 2/3) + expect_equal(executeFunctionInSQL(connection, funcNum, as.integer(3), 2), 3/2) +}) + +test_that("Test Return", { + myReturnVal <- function(){ + return("returned!") + } + + val = executeFunctionInSQL(connection, myReturnVal) + expect_equal(val, myReturnVal()) +}) + +test_that("Test Warning", { + printWarning <- function(){ + warning("testWarning") + print("Hello, this returned") + } + expect_warning( + expect_output(executeFunctionInSQL(connection, printWarning), + "Hello, this returned"), + "testWarning") + +}) + +test_that("Passing in a user defined function", { + func1 <- function(){ + func2 <- function() { + return("Success") + } + return(func2()) + } + + expect_equal(executeFunctionInSQL(connection, func=func1), "Success") +}) + +test_that("Returning a function object", { + func2 <- function() { + return("Success") + } + func1 <- function(){ + func2 <- function() { + return("Success") + } + return(func2) + } + + expect_equal(executeFunctionInSQL(connection, func=func1), func2) +}) + +test_that("Calling an object in the environment", { + skip("This doesn't work right now because we don't pass the whole environment") + + func2 <- 
function() { + return("Success") + } + func1 <- function(){ + return(func2) + } + + expect_equal(executeFunctionInSQL(connection, func=func1), func2) +}) + +test_that("No Parameters test", { + noReturn <- function() { + } + result = executeFunctionInSQL(connection, noReturn) + expect_null(result) +}) + +test_that("Print, Warning, Return test", { + + returnString <- function() { + print("hello") + warning("uh oh") + return("bar") + } + expect_warning(expect_output(result <- executeFunctionInSQL(connection, returnString), "hello"), "uh oh") + + expect_equal(result , "bar") +}) + +test_that("Print, Warning, Return test, with args", { + + returnVector <- function(a,b) { + print("print") + warning("uh oh") + return(c(a,b)) + } + expect_warning(expect_output(result <- executeFunctionInSQL(connection, returnVector, "foo", "bar"), "print"), "uh oh") + + expect_equal(result , c("foo","bar")) +}) + +test_that("Print, Warning, Error test", { + testError <- function() { + print("print") + warning("warning") + stop("ERROR") + } + expect_error( + expect_warning( + expect_output( + result <- executeFunctionInSQL(connection, testError), + "print"), + "warning"), + "ERROR") +}) + +test_that("Return a DataFrame", { + + returnDF <- function(a, b) { + return(data.frame(x = c(foo=a,bar=b))) + } + result <- executeFunctionInSQL(connection, returnDF, "foo", 2) + expect_equal(result, data.frame(x = c(foo="foo",bar=2))) +}) + +test_that("Return an input DataFrame", { + useInputDataSet <- function(in_df) { + return(in_df) + } + result = executeFunctionInSQL(connection, useInputDataSet, inputDataQuery = "SELECT TOP 5 * FROM airline5000") + expect_equal(nrow(result), 5) + expect_equal(ncol(result), 30) + + useInputDataSet2 <- function(in_df, t1) { + return(list(in_df, t1=t1)) + } + result = executeFunctionInSQL(connection, useInputDataSet2, t1=5, inputDataQuery = "SELECT TOP 5 * FROM airline5000") + expect_equal(result$t1, 5) + expect_equal(ncol(result[[1]]), 30) + +}) + +test_that("Variable 
test", { + + printString <- function(str) { + print(str) + } + expect_output(executeFunctionInSQL(connection, printString, str="Hello"), "Hello") + test <- "World" + expect_output(executeFunctionInSQL(connection, printString, str=test), test) +}) + +test_that("Query test", { + res <- executeSQLQuery(connectionString = connection, sqlQuery = "SELECT TOP 5 * FROM airline5000") + expect_equal(nrow(res), 5) + expect_equal(ncol(res), 30) +}) + +test_that("Script test", { + script <- file.path(scriptDir, 'script.txt') + + expect_warning( + expect_output( + res <- executeScriptInSQL(connectionString=connection, script=script, inputDataQuery = "SELECT TOP 5 * FROM airline5000"), + "Hello"), + "WARNING") + expect_equal(nrow(res), 5) + expect_equal(ncol(res), 30) + + script2 <- file.path(scriptDir, 'script2.txt') + + + expect_output(res <- executeScriptInSQL(connection, script2), "Script path exists") + expect_equal(res, 33) + + expect_error(res <- executeScriptInSQL(connection, "non-existent-script.txt"), regexp = "Script path doesn't exist") + +}) diff --git a/R/tests/testthat/test.storedProcedureTests.R b/R/tests/testthat/test.storedProcedureTests.R new file mode 100644 index 0000000..9c7afec --- /dev/null +++ b/R/tests/testthat/test.storedProcedureTests.R @@ -0,0 +1,319 @@ +# Copyright(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+ +library(testthat) +context("Stored Procedure tests") + +TestArgs <- options('TestArgs')$TestArgs +connection <- TestArgs$connectionString +scriptDir <- TestArgs$scriptDirectory + +dropIfExists <- function(connectionString, name) { + if(checkSproc(connectionString, name)) + invisible(capture.output(dropSproc(connectionString = connectionString, name = name))) +} + +# +#Test an empty function (no inputs) +test_that("No Parameters test", { + noParams <- function() { + data.frame(hello = "world") + } + name = "noParams" + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) + + capture.output(createSprocFromFunction(name, noParams, connectionString = connection)) + expect_true(checkSproc(name, connectionString = connection)) + + expect_equal(as.character(executeSproc(connectionString = connection, name)[[1]]) , "world") + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) +}) + +# +#Test multiple input parameters +#("posixct", "numeric", "character", "integer", "logical", "raw", "dataframe") +test_that("Numeric, POSIXct, Character, Logical test", { + inNumCharParams <- function(in1, in2, in3, in4) { + data.frame(in1, in2,in3,in4) + } + + #TODO: Time zone might not work + x <- as.POSIXct(12345678, origin = "1960-01-01")#, tz= "GMT") + + inputParams <- list(in1="numeric", in2="posixct", in3="character", in4="logical") + + name = "inNumCharParams" + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) + + capture.output(createSprocFromFunction(name, inNumCharParams, connectionString = connection, inputParams = inputParams)) + expect_true(checkSproc(name, connectionString = connection)) + + res <- executeSproc(name, in1 = 1, in2 = x, in3 = "Hello", in4 = 1, connectionString = connection) + + expect_equal(res[[1]], 1) + expect_equal(res[[2]], x) + expect_equal(as.character(res[[3]]), 
"Hello") + expect_equal(as.logical(res[[4]]), TRUE) + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) +}) + +# +#Test only an InputDataSet StoredProcedure +test_that("Simple InputDataSet test", { + inData <- function(in_df) { + in_df + } + + inputParams <- list(in_df="dataframe") + + name = "inData" + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) + + capture.output(createSprocFromFunction(name, inData, connectionString = connection, inputParams = inputParams)) + expect_true(checkSproc(name, connectionString = connection)) + + res <- executeSproc(name, in_df = "SELECT TOP 10 * FROM airline5000", connectionString = connection) + expect_equal(nrow(res), 10) + expect_equal(ncol(res), 30) + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) +}) + + +# +#Test InputDataSet with returned OutputDataSet +test_that("InputDataSet to OutputDataSet test", { + inOutData <- function(in_df) { + list(out_df = in_df) + } + + inputParams <- list(in_df="dataframe") + outputParams <- list(out_df="dataframe") + + name = "inOutData" + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) + + capture.output(createSprocFromFunction(name, inOutData, connectionString = connection, inputParams = inputParams, outputParams = outputParams)) + expect_true(checkSproc(name, connectionString = connection)) + + res <- executeSproc(name, in_df = "SELECT TOP 10 * FROM airline5000", connectionString = connection) + expect_equal(nrow(res), 10) + expect_equal(ncol(res), 30) + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) +}) + +# +#Test InputDataSet query with InputParameters +test_that("InputDataSet with InputParameter test", { + inDataParams <- function(id, ip) { + 
rbind(id,ip) + } + + name = "inDataParams" + + inputParams = list(id = "DataFrame", ip = "numeric") + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) + + capture.output(createSprocFromFunction(name, inDataParams, connectionString = connection, inputParams = inputParams)) + expect_true(checkSproc(name, connectionString = connection)) + + res <- executeSproc(name, id = "SELECT TOP 10 * FROM airline5000", ip = 4, connectionString = connection) + + expect_equal(nrow(res), 11) + expect_equal(ncol(res), 30) + + expect_error(executeSproc(name, "SELECT TOP 10 * FROM airline5000", ip = 4, connectionString = connection)) + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) +}) + + +# +#Test InputDataSet query with InputParameters with inputs out of order +test_that("InputDataSet with InputParameter test, out of order", { + inDataParams <- function(id, ip, ip2) { + rbind(id,ip) + } + + name = "inDataParamsOoO" + + inputParams = list(ip = "numeric", id = "DATAFRAME", ip2 = "character") + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) + + capture.output(createSprocFromFunction(name, inDataParams, connectionString = connection, inputParams = inputParams)) + expect_true(checkSproc(name, connectionString = connection)) + + res <- executeSproc(name, ip2 = "Hello", ip = 4, id = "SELECT TOP 10 * FROM airline5000", connectionString = connection) + + expect_equal(nrow(res), 11) + expect_equal(ncol(res), 30) + + expect_error(executeSproc(name,ip = 4, "SELECT TOP 10 * FROM airline5000", connectionString = connection)) + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) +}) + + +test_that("Stored Procedure with Scripts", { + inputParams <- list(num1 = "numeric", num2 = "numeric", in_df = "dAtaFrame") + outputParams <- 
list(out_df = "dataframe") + + name="script" + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) + + capture.output(createSprocFromScript( + connectionString = connection, name=name, file.path(scriptDir, "script3.R"), inputParams = inputParams, outputParams = outputParams)) + expect_true(checkSproc(connectionString = connection, name = name)) + + retVal <- executeSproc(connectionString = connection, name, num1 = 3, num2 = 4, in_df = "select top 10 * from airline5000") + + expect_equal(nrow(retVal), 11) + expect_equal(ncol(retVal), 30) + + dropIfExists(connectionString = connection, name = name) + expect_false(checkSproc(connectionString = connection, name = name)) +}) + +context("Sprocs with output params (TODO)") + +# TODO: Output params test - execution doesn't work right now +test_that("Only OuputParams test", { + outsFunc <- function(arg1) { + list(res = paste0("Hello ", arg1, "!")) + } + + name <- "outsFunc" + inputParams <- list(arg1 = "character") + outputParams <- list(res = "character") + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) + + capture.output(createSprocFromFunction(name, outsFunc, connectionString = connection, inputParams = inputParams, outputParams = outputParams)) + expect_true(checkSproc(name, connectionString = connection)) + + + #Use T-SQL to verify + sql_str = "DECLARE @res nvarchar(max) EXEC outsFunc @arg1_outer = N'T-SQL', @res_outer = @res OUTPUT SELECT @res as N'@res'" + out <- system2("sqlcmd.exe", c("-S", "localhost", "-E", "-d","AirlineTestDB", "-Q", paste0('"', sql_str, '"')), stdout=TRUE) + expect_true(any(grepl("Hello T-SQL!", out))) + #executeSproc(name, connectionString = connection, out1 = "Asd", out2 = "wqe") + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) +}) + + +test_that("OutputDataSet and OuputParams test", { + 
outDataParam <- function() { + df = data.frame(hello = "world") + list(df = df, out1 = "Hello", out2 = "World") + } + name = "outDataParam" + + outputParams <- list(df = "dataframe", out2 = "character", out1 = "character") + + createSprocFromFunction(name, outDataParam, connectionString = connection, outputParams = outputParams) + stopifnot(checkSproc(name, connectionString = connection)) + + #Use T-SQL to verify + sql_str = "DECLARE @out1 nvarchar(max),@out2 nvarchar(max) EXEC outDataParam @out1_outer = @out1 OUTPUT, @out2_outer = @out2 OUTPUT SELECT @out1 as N'@out1'" + out <- system2("sqlcmd.exe", c("-S", "localhost", "-E", "-d","AirlineTestDB", "-Q", paste0('"', sql_str, '"')), stdout=TRUE) + expect_true(any(grepl("Hello", out))) + #res <- executeSproc(connectionString = connection, name) + + dropSproc(name, connectionString = connection) + stopifnot(!checkSproc(name, connectionString = connection)) +}) + +context("Sproc Negative Tests") + +test_that("Bad input param types or usage", { + badParam <- function(arg1) { + return(arg1) + } + inputParams <- list(arg1 = "NotAType") + + expect_error(createSprocFromFunction(connection, "badParam", badParam, inputParams = inputParams)) + + inputParams <- list(arg1 = "dataframe") + + name = "badInput" + dropIfExists(connection, name) + capture.output(createSprocFromFunction(connection, name, badParam, inputParams = inputParams)) + expect_true(checkSproc(connection, name)) + + expect_error(expect_warning(executeSproc(connection, name, arg1=12314532))) + res <- executeSproc(connection, name, arg1="SELECT TOP 5 * FROM airline5000") + + expect_equal(ncol(res), 30) + expect_equal(nrow(res), 5) + dropIfExists(connection, name) +}) + +test_that("Drop nonexistent sproc",{ + expect_false(checkSproc(connection, "NonexistentSproc")) + expect_output(dropSproc(connection, "NonexistentSproc"), "Named procedure doesn't exist") +}) + +test_that("Create with bad name",{ + name = "'''asd''asd''asd" + foo = function() { + return(NULL) + } + 
expect_error(createSprocFromFunction(connection, name, foo)) +}) + +test_that("mismatch input params", { + func <- function(arg1, arg2) { + return(arg1) + } + inputParams <- list(arg1 = "dataframe", arg3 = "numeric") + + dropIfExists(connection, "mismatch") + expect_error(createSprocFromFunction(connection, "mismatch", func, inputParams = inputParams)) + + inputParams <- list(arg1 = "dataframe", arg2 = "qwe", arg3 = "numeric") + + dropIfExists(connection, "mismatch") + expect_error(createSprocFromFunction(connection, "mismatch", func, inputParams = inputParams)) + +}) + + +test_that("Sproc with Bad Script Path", { + name="bad_script_path" + + dropIfExists(name, connectionString = connection) + expect_false(checkSproc(name, connectionString = connection)) + + expect_error(createSprocFromScript( + connectionString = connection, name=name, "bad_script_path.txt")) + +}) + + + diff --git a/README.md b/README.md index 72f1506..eb4ae10 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,38 @@ +# sqlmlutils -# Contributing +sqlmlutils is a package designed to help users interact with SQL Server and execute R or Python code from an R/Python client. -This project welcomes contributions and suggestions. Most contributions require you to agree to a -Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us -the rights to use your contribution. For details, visit https://cla.microsoft.com. +# Installation -When you submit a pull request, a CLA-bot will automatically determine whether you need to provide -a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions -provided by the bot. You will only need to do this once across all repos using our CLA. +To install sqlmlutils from this repository, run the following commands from the root folder: -This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
-For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or -contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. +Python: +1. If your client is a Linux machine, you can skip this step. If your client is a Windows machine: go to https://www.lfd.uci.edu/~gohlke/pythonlibs/#pymssql and download the correct version of pymssql for your client. Run ```pip install pymssql-2.1.4.dev5-cpXX-cpXXm-win_amd64.whl``` on the downloaded file to install pymssql. +2. Run +``` +python.exe -m pip install Python/dist/sqlmlutils-0.5.0.zip --upgrade +``` + +R: +``` +R CMD INSTALL R/dist/sqlmlutils_0.5.0.zip +``` + +# Details + +sqlmlutils contains three main parts: +- Execution of Python/R in SQL Server using sp_execute_external_script +- Creation and execution of stored procedures created from scripts and functions +- Installation and management of packages in SQL Server + +## Execute in SQL + +Execute in SQL provides a convenient way to execute arbitrary Python/R code inside SQL Server using sp_execute_external_script. The user does not need to know any T-SQL to use this function. Function arguments are serialized into binary and passed into the generated T-SQL script. Warnings and printed output are displayed at the end of execution, and any results returned by the function are passed back to the client. + +## Stored Procedures (Sprocs) + +The goal of this utility is to let users create and execute stored procedures on their database without needing to know the exact syntax for creating one. Functions and scripts are wrapped into a stored procedure and registered in a database, and can then be executed from the Python/R client. + +## Package Management + +With package management, users can install packages on a remote SQL Server from a client machine. The packages are downloaded on the client and then sent over to SQL Server, where they are installed into library folders. 
The folders are per-database, so packages are always installed and made available for a specific database. The package management APIs provide PUBLIC and PRIVATE folders. Packages in the PUBLIC folder are accessible to all database users. Packages in the PRIVATE folder are accessible only to the user who installed the package.
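
As a sketch of how these pieces fit together from an R client (function and argument names follow the R tests and documentation in this repository; the server name is a placeholder, exact signatures may differ, and the calls require a reachable SQL Server instance):

```R
library(sqlmlutils)

# Build a connection string to the target database (placeholder server name)
connection <- connectionInfo(server = "myserver", database = "AirlineTestDB")

# Execute an arbitrary R function inside SQL Server and get the result back
executeFunctionInSQL(connection, function(x, y) x + y, x = 1, y = 2)

# Wrap a function in a stored procedure, run it, and clean up
createSprocFromFunction("addOne", function(n) data.frame(res = n + 1),
                        connectionString = connection,
                        inputParams = list(n = "numeric"))
executeSproc("addOne", n = 41, connectionString = connection)
dropSproc("addOne", connectionString = connection)
```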