Change pymssql -> pyODBC (#58)

Major changes to the sqlmlutils internal implementation. The user experience should be largely the same as before, with the addition of output parameters.

* Change to pyodbc instead of pymssql
* Add print capture to exec
* Add output parameter support and stdout/stderr in sprocs
* Change testing for Python package management
* README updates
* Fix license headers
* Update version number to 1.0.0

This commit is contained in:
Parent: 1c82b56e68
Commit: dade183979

@@ -3,7 +3,7 @@ sqlmlutils
MIT License

Copyright (c) Microsoft Corporation. All rights reserved.
Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

Python/README.md
@@ -1,243 +1,248 @@
# sqlmlutils

sqlmlutils is a python package to help execute Python code on a SQL Server machine. It is built to work with ML Services for SQL Server.

# Installation

To install from PyPI, run:
```
pip install sqlmlutils
```
To install from file, run:
```
pip install Python/dist/sqlmlutils-1.0.0.zip
```

Note: If you encounter errors installing the pymssql dependency and your client is a Windows machine, consider
installing the .whl file at the below link (download the file for your Python version and run pip install):
https://www.lfd.uci.edu/~gohlke/pythonlibs/#pymssql

If you are developing on your own branch and want to rebuild and install the package, you can use the buildandinstall.cmd script that is included.

# Getting started

Shown below are the important functions sqlmlutils provides:
```python
execute_function_in_sql      # Execute a python function inside the SQL database
execute_script_in_sql        # Execute a python script inside the SQL database
execute_sql_query            # Execute a sql query in the database and return the resultant table

create_sproc_from_function   # Create a stored procedure based on a Python function inside the SQL database
create_sproc_from_script     # Create a stored procedure based on a Python script inside the SQL database
check_sproc                  # Check whether a stored procedure exists in the SQL database
drop_sproc                   # Drop a stored procedure from the SQL database
execute_sproc                # Execute a stored procedure in the SQL database

install_package              # Install a Python package on the SQL database
remove_package               # Remove a Python package from the SQL database
list                         # Enumerate packages that are installed on the SQL database
get_packages_by_user         # Enumerate external libraries installed by specific user in specific scope
```

# Examples

### Execute in SQL
##### Execute a python function in database

```python
import sqlmlutils

def foo():
    return "bar"

# For Linux SQL Server, you must specify the ODBC Driver and the username/password because there is no Trusted_Connection/Implied Authentication support yet.
# connection = sqlmlutils.ConnectionInfo(driver="ODBC Driver 13 for SQL Server", server="localhost", database="master", uid="username", pwd="password")

connection = sqlmlutils.ConnectionInfo(server="localhost", database="master")

sqlpy = sqlmlutils.SQLPythonExecutor(connection)
result = sqlpy.execute_function_in_sql(foo)
assert result == "bar"
```
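
With this change, stdout and stderr produced inside SQL Server (for example `print` calls in the function) are captured and returned alongside the result (see "Add print capture to exec" in the commit summary). A minimal sketch of what that enables; exactly how the captured text is surfaced on the client side is an assumption here:

```python
import sqlmlutils

def greet(name: str):
    # print output is captured by the generated sp_execute_external_script wrapper
    print("Hello from inside SQL Server, " + name)
    return name.upper()

connection = sqlmlutils.ConnectionInfo(server="localhost", database="master")
sqlpy = sqlmlutils.SQLPythonExecutor(connection)

result = sqlpy.execute_function_in_sql(greet, name="sqlmlutils")
assert result == "SQLMLUTILS"
# The "Hello from inside SQL Server, ..." text is carried back in the _stdout_
# column added by this change rather than being lost on the server.
```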

##### Generate a scatter plot without the data leaving the machine

```python
import sqlmlutils
from PIL import Image


def scatter_plot(input_df, x_col, y_col):
    import matplotlib.pyplot as plt
    import io

    title = x_col + " vs. " + y_col

    plt.scatter(input_df[x_col], input_df[y_col])
    plt.xlabel(x_col)
    plt.ylabel(y_col)
    plt.title(title)

    # Save scatter plot image as a png
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)

    # Returns the bytes of the png to the client
    return buf

# For Linux SQL Server, you must specify the ODBC Driver and the username/password because there is no Trusted_Connection/Implied Authentication support yet.
# connection = sqlmlutils.ConnectionInfo(driver="ODBC Driver 13 for SQL Server", server="localhost", database="AirlineTestDB", uid="username", pwd="password")

connection = sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB")

sqlpy = sqlmlutils.SQLPythonExecutor(connection)

sql_query = "select top 100 * from airline5000"
plot_data = sqlpy.execute_function_in_sql(func=scatter_plot, input_data_query=sql_query,
                                          x_col="ArrDelay", y_col="CRSDepTime")
im = Image.open(plot_data)
im.show()
```

##### Perform linear regression on data stored in SQL Server without the data leaving the machine

You can use the AirlineTestDB (supplied as a .bak file above) to run these examples.

```python
import sqlmlutils

def linear_regression(input_df, x_col, y_col):
    from sklearn import linear_model

    X = input_df[[x_col]]
    y = input_df[y_col]

    lr = linear_model.LinearRegression()
    lr.fit(X, y)

    return lr

# For Linux SQL Server, you must specify the ODBC Driver and the username/password because there is no Trusted_Connection/Implied Authentication support yet.
# connection = sqlmlutils.ConnectionInfo(driver="ODBC Driver 13 for SQL Server", server="localhost", database="AirlineTestDB", uid="username", pwd="password")

connection = sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB")

sqlpy = sqlmlutils.SQLPythonExecutor(connection)
sql_query = "select top 1000 CRSDepTime, CRSArrTime from airline5000"
regression_model = sqlpy.execute_function_in_sql(linear_regression, input_data_query=sql_query,
                                                 x_col="CRSDepTime", y_col="CRSArrTime")
print(regression_model)
print(regression_model.coef_)
```

##### Execute a SQL Query from Python

```python
import sqlmlutils
import pytest

# For Linux SQL Server, you must specify the ODBC Driver and the username/password because there is no Trusted_Connection/Implied Authentication support yet.
# connection = sqlmlutils.ConnectionInfo(driver="ODBC Driver 13 for SQL Server", server="localhost", database="AirlineTestDB", uid="username", pwd="password")

connection = sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB")

sqlpy = sqlmlutils.SQLPythonExecutor(connection)
sql_query = "select top 10 * from airline5000"
data_table = sqlpy.execute_sql_query(sql_query)
assert len(data_table.columns) == 30
assert len(data_table) == 10
```

### Stored Procedure
##### Create and call a T-SQL stored procedure based on a Python function

```python
import sqlmlutils
import pytest

def principal_components(input_table: str, output_table: str):
    import sqlalchemy
    from urllib import parse
    import pandas as pd
    from sklearn.decomposition import PCA

    # Internal ODBC connection string used by process executing inside SQL Server
    connection_string = "Driver=SQL Server;Server=localhost;Database=AirlineTestDB;Trusted_Connection=Yes;"
    engine = sqlalchemy.create_engine("mssql+pyodbc:///?odbc_connect={}".format(parse.quote_plus(connection_string)))

    input_df = pd.read_sql("select top 200 ArrDelay, CRSDepTime from {}".format(input_table), engine).dropna()

    pca = PCA(n_components=2)
    components = pca.fit_transform(input_df)

    output_df = pd.DataFrame(components)
    output_df.to_sql(output_table, engine, if_exists="replace")


# For Linux SQL Server, you must specify the ODBC Driver and the username/password because there is no Trusted_Connection/Implied Authentication support yet.
# connection = sqlmlutils.ConnectionInfo(driver="ODBC Driver 13 for SQL Server", server="localhost", database="AirlineTestDB", uid="username", pwd="password")

connection = sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB")

input_table = "airline5000"
output_table = "AirlineDemoPrincipalComponents"

sp_name = "SavePrincipalComponents"

sqlpy = sqlmlutils.SQLPythonExecutor(connection)

if sqlpy.check_sproc(sp_name):
    sqlpy.drop_sproc(sp_name)

sqlpy.create_sproc_from_function(sp_name, principal_components)

# You can check the stored procedure exists in the db with this:
assert sqlpy.check_sproc(sp_name)

sqlpy.execute_sproc(sp_name, input_table=input_table, output_table=output_table)

sqlpy.drop_sproc(sp_name)
assert not sqlpy.check_sproc(sp_name)
```
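
This change also adds output parameter support for stored procedures (see the commit summary above). The sketch below is a hypothetical illustration based on the builder changes in this commit; the exact `input_params`/`output_params` arguments on the client API and the dict-return convention are assumptions, not documented behavior:

```python
import sqlmlutils

def add_numbers(x: int, y: int):
    # Returning a dict is assumed to map values onto declared output parameters
    return {"total": x + y}

connection = sqlmlutils.ConnectionInfo(server="localhost", database="master")
sqlpy = sqlmlutils.SQLPythonExecutor(connection)

sqlpy.create_sproc_from_function("AddNumbers", add_numbers,
                                 input_params={"x": int, "y": int},
                                 output_params={"total": int})

# Assumed call pattern: output parameter values come back to the client,
# along with the captured stdout/stderr added by this change.
result = sqlpy.execute_sproc("AddNumbers", output_params={"total": int}, x=1, y=2)

sqlpy.drop_sproc("AddNumbers")
```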

### Package Management

##### Python package management with sqlmlutils is supported in SQL Server 2019 CTP 2.4 and later.

##### Install and remove packages from SQL Server

```python
import sqlmlutils

# For Linux SQL Server, you must specify the ODBC Driver and the username/password because there is no Trusted_Connection/Implied Authentication support yet.
# connection = sqlmlutils.ConnectionInfo(driver="ODBC Driver 13 for SQL Server", server="localhost", database="AirlineTestDB", uid="username", pwd="password")

connection = sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB")
pkgmanager = sqlmlutils.SQLPackageManager(connection)
pkgmanager.install("astor")

def import_astor():
    import astor

# import the astor package to make sure it installed properly
sqlpy = sqlmlutils.SQLPythonExecutor(connection)
val = sqlpy.execute_function_in_sql(import_astor)

pkgmanager.uninstall("astor")
```
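
The `list` and `get_packages_by_user` functions shown in the Getting started section can be used to check what is installed; a short sketch, noting that the argument names and return shape here are assumptions:

```python
# Continuing from the example above: enumerate external libraries on the server.
installed = pkgmanager.list()
print(installed)

# Enumerate external libraries installed by a specific user in a specific scope
per_user = pkgmanager.get_packages_by_user("Tester", scope=sqlmlutils.Scope.private_scope())
print(per_user)
```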

# Notes for Developers

### Running the tests

1. Make sure a SQL Server with an updated ML Services Python is running on localhost.
2. Restore the AirlineTestDB from the .bak file in this repo
3. Make sure Trusted (Windows) authentication works for connecting to the database
4. Setup a user with db_owner role (and not server admin) with uid: "Tester" and password "FakeT3sterPwd!"

### Notable TODOs and open issues

1. The pymssql library is hard to install. Users need to install the .whl files from the link above, not
the .whl files currently hosted in PyPI. Because of this, we should consider moving to use pyodbc.
2. Testing from a Linux client has not been performed.
3. The way we get dependencies of a package to install is sort of hacky (parsing pip output)
4. Output Parameter execution currently does not work - can potentially use MSSQLStoredProcedure binding

@@ -1,2 +1,2 @@
python.exe setup.py sdist --formats=zip
python.exe -m pip install --upgrade --upgrade-strategy only-if-needed dist\sqlmlutils-0.7.3.zip
python.exe -m pip install --upgrade --upgrade-strategy only-if-needed dist\sqlmlutils-1.0.0.zip

Binary file not shown.

@@ -1,5 +1,5 @@
pip>=9.0.1
pymssql>=2.1.4,<3.0
pyodbc>=4.0.25
dill>=0.2.6
pkginfo>=1.4.2
requirements-parser>=0.2.0

@@ -1,4 +1,4 @@
# Copyright(c) Microsoft Corporation. All rights reserved.
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.

from setuptools import setup

@@ -6,7 +6,7 @@ from setuptools import setup
setup(
    name='sqlmlutils',
    packages=['sqlmlutils', 'sqlmlutils/packagemanagement'],
    version='0.7.3',
    version='1.0.0',
    url='https://github.com/Microsoft/sqlmlutils/Python',
    license='MIT License',
    desciption='A client side package for working with SQL Server',

@@ -16,7 +16,7 @@ setup(
    author_email='joz@microsoft.com',
    install_requires=[
        'pip',
        'pymssql',
        'pyodbc',
        'dill',
        'pkginfo',
        'requirements-parser',

@@ -1,6 +1,7 @@
# Copyright(c) Microsoft Corporation. All rights reserved.
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.

from .connectioninfo import ConnectionInfo
from .sqlpythonexecutor import SQLPythonExecutor
from .packagemanagement.sqlpackagemanager import SQLPackageManager
from .packagemanagement.scope import Scope
from .packagemanagement.sqlpackagemanager import SQLPackageManager

@@ -1,4 +1,4 @@
# Copyright(c) Microsoft Corporation. All rights reserved.
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.

class ConnectionInfo:

@@ -52,10 +52,15 @@ class ConnectionInfo:
    @property
    def connection_string(self):
        server = self._server if self._port == "" \
            else "{server},{port}".format(server=self._server, port=self._port)

        auth = "Trusted_Connection=Yes" if self._uid == "" \
            else "uid={uid};pwd={pwd}".format(uid=self._uid, pwd=self._pwd)

        return "Driver={driver};Server={server};Database={database};{auth};".format(
            driver=self._driver,
            server=self._server if self._port == "" else "{servername},{port}".format(servername=self._server, port=self._port),
            database=self._database,
            auth="Trusted_Connection=Yes" if self._uid == "" else
                "uid={uid};pwd={pwd}".format(uid=self._uid, pwd=self._pwd)
        )
            driver = self._driver,
            server = server,
            database = self._database,
            auth = auth
        )
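
For reference, a brief sketch of what this refactored property produces; the behavior is unchanged by the refactor, and the default driver value and the `port` keyword shown here are assumptions about ConnectionInfo's interface:

```python
import sqlmlutils

# No uid given, so the auth clause is "Trusted_Connection=Yes"
trusted = sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB")
print(trusted.connection_string)
# e.g. "Driver=<default driver>;Server=localhost;Database=AirlineTestDB;Trusted_Connection=Yes;"

# SQL authentication with an explicit port: the server part renders as "host,port"
sql_auth = sqlmlutils.ConnectionInfo(driver="ODBC Driver 13 for SQL Server",
                                     server="localhost", port="1433",
                                     database="AirlineTestDB",
                                     uid="Tester", pwd="FakeT3sterPwd!")
print(sql_auth.connection_string)
# "Driver=ODBC Driver 13 for SQL Server;Server=localhost,1433;Database=AirlineTestDB;uid=Tester;pwd=FakeT3sterPwd!;"
```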

@@ -1 +1,4 @@
from .sqlpackagemanager import SQLPackageManager
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.

from .sqlpackagemanager import SQLPackageManager

@@ -1,11 +1,10 @@
# Copyright(c) Microsoft Corporation. All rights reserved.
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.

import operator

from distutils.version import LooseVersion


class DependencyResolver:

    def __init__(self, server_packages, target_package):

@@ -13,7 +12,7 @@ class DependencyResolver:
        self._target_package = target_package

    def requirement_met(self, upgrade: bool, version: str = None) -> bool:
        exists = self._package_exists_on_server()
        exists = self._package_exists_on_server(self._target_package)
        return exists and (not upgrade or
                           (version is not None and self.get_target_server_version() != "" and
                            LooseVersion(self.get_target_server_version()) >= LooseVersion(version)))

@@ -27,7 +26,7 @@ class DependencyResolver:
    def get_required_installs(self, target_requirements):
        required_packages = []
        for requirement in target_requirements:
            reqmet = any(package[0] == requirement.name for package in self._server_packages)
            reqmet = self._package_exists_on_server(requirement.name)

            for spec in requirement.specs:
                reqmet = reqmet & self._check_if_installed_package_meets_spec(

@@ -37,8 +36,10 @@ class DependencyResolver:
                required_packages.append(self.clean_requirement_name(requirement.name))
        return required_packages

    def _package_exists_on_server(self):
        return any([serverpkg[0].lower() == self._target_package.lower() for serverpkg in self._server_packages])
    def _package_exists_on_server(self, pkgname):
        return any([self.clean_requirement_name(pkgname.lower()) ==
                    self.clean_requirement_name(serverpkg[0].lower())
                    for serverpkg in self._server_packages])

    @staticmethod
    def clean_requirement_name(reqname: str):

@@ -49,7 +50,10 @@ class DependencyResolver:
        op_str = spec[0]
        req_version = spec[1]

        installed_package_name_and_version = [package for package in package_tuples if package[0] == name]
        installed_package_name_and_version = [package for package in package_tuples \
            if DependencyResolver.clean_requirement_name(name.lower()) == \
               DependencyResolver.clean_requirement_name(package[0].lower())]

        if not installed_package_name_and_version:
            return False
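
The hunks above switch the existence check from raw string equality to comparing names normalized through clean_requirement_name and lower(). A small illustration of why that matters; the separator normalization shown in the stand-in helper is an assumption about what clean_requirement_name actually does:

```python
# Hypothetical stand-in for DependencyResolver.clean_requirement_name;
# the real rule may differ, but the intent is case/separator insensitivity.
def clean_requirement_name(reqname: str) -> str:
    return reqname.replace("-", "_").lower()

server_packages = [("requirements_parser", "0.2.0"), ("Dill", "0.2.9")]

# Raw equality (old behavior) misses both of these:
assert not any(pkg[0] == "requirements-parser" for pkg in server_packages)
assert not any(pkg[0] == "dill" for pkg in server_packages)

# Normalized comparison (new behavior) finds them:
def exists_on_server(name):
    return any(clean_requirement_name(name) == clean_requirement_name(pkg[0])
               for pkg in server_packages)

assert exists_on_server("requirements-parser")
assert exists_on_server("dill")
```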

@@ -1,10 +1,11 @@
# Copyright(c) Microsoft Corporation. All rights reserved.
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.

from distutils.version import LooseVersion
import pip
import warnings
import sys
import warnings

from distutils.version import LooseVersion

pipversion = LooseVersion(pip.__version__)

@@ -1,11 +1,22 @@
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.

def no_upgrade(pkgname: str, serverversion: str, pkgversion: str = ""):
    return """
Package {pkgname} exists on server. Set upgrade to True to force upgrade.".format(pkgname))
Package {pkgname} exists on server. Set upgrade to True in install to force upgrade.
The version of {pkgname} you are trying to install is {pkgversion}.
The version installed on the server is {serverversion}
""".format(pkgname=pkgname, pkgversion=pkgversion, serverversion=serverversion)
""".format(
    pkgname=pkgname,
    pkgversion=pkgversion,
    serverversion=serverversion
)


def install(pkgname: str, version: str, targetpackage: bool):
    target = "target package" if targetpackage else "required dependency"
    return "Installing {} {} version {}".format(target, pkgname, version)
    return "Installing {target} {pkgname} version {version}".format(
        target=target,
        pkgname=pkgname,
        version=version
    )

@@ -1,9 +0,0 @@
import sys
import io


class OutputCapture(io.StringIO):

    def write(self, txt):
        sys.__stdout__.write(txt)
        super().write(txt)

@@ -1,3 +1,8 @@
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.

import pyodbc

from sqlmlutils.sqlbuilder import SQLBuilder
from sqlmlutils.packagemanagement.scope import Scope

@@ -7,25 +12,60 @@ class CreateLibraryBuilder(SQLBuilder):
    def __init__(self, pkg_name: str, pkg_filename: str, scope: Scope):
        self._name = clean_library_name(pkg_name)
        self._filename = pkg_filename
        self._has_params = True
        self._scope = scope

    @property
    def params(self):
        with open(self._filename, "rb") as f:
            pkgdatastr = "0x" + f.read().hex()
            package_bits = f.read()
            pkgdatastr = pyodbc.Binary(package_bits)
        return pkgdatastr

        installcheckscript = """
    @property
    def base_script(self) -> str:
        sqlpkgname = self._name
        authorization = _get_authorization(self._scope)
        dummy_spees = _get_dummy_spees()

        return """
set NOCOUNT on
-- Drop the library if it exists
BEGIN TRY
DROP EXTERNAL LIBRARY [{sqlpkgname}] {authorization}
END TRY
BEGIN CATCH
END CATCH

-- Create the library
CREATE EXTERNAL LIBRARY [{sqlpkgname}] {authorization}
FROM (CONTENT = ?) WITH (LANGUAGE = 'Python');

-- Dummy SPEES
{dummy_spees}
""".format(
    sqlpkgname=sqlpkgname,
    authorization=authorization,
    dummy_spees=dummy_spees
)


class CheckLibraryBuilder(SQLBuilder):

    def __init__(self, pkg_name: str, scope: Scope):
        self._name = clean_library_name(pkg_name)
        self._scope = scope

    @property
    def params(self):
        return """
import os
import re
_ENV_NAME_USER_PATH = "MRS_EXTLIB_USER_PATH"
_ENV_NAME_SHARED_PATH = "MRS_EXTLIB_SHARED_PATH"


def _is_dist_info_file(name, file):
    return re.match(name + r"-.*egg", file) or re.match(name + r"-.*dist-info", file)


def _is_package_match(package_name, file):
    package_name = package_name.lower()
    file = file.lower()

@@ -49,46 +89,26 @@ def package_exists_in_scope(sql_package_name: str, scope=None) -> bool:
    package_files = package_files_in_scope(scope)
    return any([_is_package_match(sql_package_name, package_file) for package_file in package_files])


assert package_exists_in_scope("{sqlpkgname}", "{scopestr}")
""".format(sqlpkgname=self._name, scopestr=self._scope._name)

        return pkgdatastr, installcheckscript
# Check that the package exists in scope.
# For some reason this check works but there is a bug in pyODBC when asserting this is True.
assert package_exists_in_scope("{name}", "{scope}") != False
""".format(name=self._name, scope=self._scope._name)

    @property
    def base_script(self) -> str:
        return """
-- Drop the library if it exists
BEGIN TRY
DROP EXTERNAL LIBRARY [{sqlpkgname}] {authorization}
END TRY
BEGIN CATCH
END CATCH

-- Parameter bind the package data
DECLARE @content varbinary(MAX) = convert(varbinary(MAX), %s, 1);

-- Create the library
CREATE EXTERNAL LIBRARY [{sqlpkgname}] {authorization}
FROM (CONTENT = @content) WITH (LANGUAGE = 'Python');

-- Dummy SPEES
{dummy_spees}

        return """
-- Check to make sure the package was installed
BEGIN TRY
exec sp_execute_external_script
@language = N'Python',
@script = %s
@script = ?
print('Package successfully installed.')
END TRY
BEGIN CATCH
print('Package installation failed.');
THROW;
END CATCH
""".format(sqlpkgname=self._name,
           authorization=_get_authorization(self._scope),
           dummy_spees=_get_dummy_spees())
"""


class DropLibraryBuilder(SQLBuilder):

@@ -100,10 +120,14 @@ class DropLibraryBuilder(SQLBuilder):
    @property
    def base_script(self) -> str:
        return """
DROP EXTERNAL LIBRARY [{}] {authorization}
DROP EXTERNAL LIBRARY [{name}] {auth}

{dummy_spees}
""".format(self._name, authorization=_get_authorization(self._scope), dummy_spees=_get_dummy_spees())
""".format(
    name=self._name,
    auth=_get_authorization(self._scope),
    dummy_spees=_get_dummy_spees()
)


def clean_library_name(pkgname: str):
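
The key change in this file is that the library contents are now passed as a bound pyodbc parameter (pyodbc.Binary with a ? placeholder) instead of being spliced into the CREATE EXTERNAL LIBRARY statement as a 0x... hex literal. A standalone sketch of that binding pattern; the connection string and table are hypothetical:

```python
import pyodbc

# Hypothetical connection; any ODBC target behaves the same way
cnxn = pyodbc.connect("Driver={ODBC Driver 13 for SQL Server};"
                      "Server=localhost;Database=master;Trusted_Connection=Yes;")
cursor = cnxn.cursor()

with open("package.zip", "rb") as f:
    payload = pyodbc.Binary(f.read())

# The ? placeholder is bound to the varbinary payload by the driver,
# avoiding hex-literal string building and its quoting and size pitfalls.
cursor.execute("INSERT INTO dbo.PackageBlobs (name, content) VALUES (?, ?)",
               "package.zip", payload)
cnxn.commit()
```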

@@ -1,3 +1,6 @@
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import re
import requirements

@@ -1,5 +1,8 @@
import pkginfo
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import pkginfo
import re

@@ -1,4 +1,4 @@
# Copyright(c) Microsoft Corporation. All rights reserved.
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.

class Scope:

@@ -1,10 +1,11 @@
# Copyright(c) Microsoft Corporation. All rights reserved.
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.

from sqlmlutils.packagemanagement.scope import Scope
import os
import re

from sqlmlutils.packagemanagement.scope import Scope

_ENV_NAME_USER_PATH = "MRS_EXTLIB_USER_PATH"
_ENV_NAME_SHARED_PATH = "MRS_EXTLIB_SHARED_PATH"

@@ -1,22 +1,20 @@
# Copyright(c) Microsoft Corporation. All rights reserved.
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import tempfile
import zipfile
import warnings
import zipfile

from sqlmlutils import ConnectionInfo, SQLPythonExecutor
from sqlmlutils.sqlqueryexecutor import execute_query, SQLTransaction
from sqlmlutils.packagemanagement.packagesqlbuilder import clean_library_name
from sqlmlutils.packagemanagement import servermethods
from sqlmlutils.sqlqueryexecutor import SQLQueryExecutor
from sqlmlutils.packagemanagement import messages, servermethods
from sqlmlutils.packagemanagement.dependencyresolver import DependencyResolver
from sqlmlutils.packagemanagement.packagesqlbuilder import CreateLibraryBuilder, CheckLibraryBuilder, \
    DropLibraryBuilder, clean_library_name
from sqlmlutils.packagemanagement.pipdownloader import PipDownloader
from sqlmlutils.packagemanagement.scope import Scope
from sqlmlutils.packagemanagement import messages
from sqlmlutils.packagemanagement.pkgutils import get_package_name_from_file, get_package_version_from_file
from sqlmlutils.packagemanagement.packagesqlbuilder import CreateLibraryBuilder, DropLibraryBuilder
from sqlmlutils.packagemanagement.scope import Scope
from sqlmlutils.sqlqueryexecutor import execute_query, SQLQueryExecutor


class SQLPackageManager:

@@ -90,7 +88,7 @@ class SQLPackageManager:
        if scope is None:
            scope = self._get_default_scope()

        print("Uninstalling " + package_name + " only, not dependencies")
        print("Uninstalling {package_name} only, not dependencies".format(package_name=package_name))
        self._drop_sql_package(package_name, scope, out_file)

    def list(self):

@@ -113,19 +111,21 @@ class SQLPackageManager:
            SELECT @currentUser = "

        if has_user:
            query += "%s;\n"
            query += "?;\n"
        else:
            query += "CURRENT_USER;\n"

        scope_num = 1 if scope == Scope.private_scope() else 0

        query += "SELECT @principalId = USER_ID(@currentUser); \
            SELECT name, language, scope \
            FROM sys.external_libraries AS elib \
            WHERE elib.principal_id=@principalId \
            AND elib.language='Python' AND elib.scope={0} \
            ORDER BY elib.name ASC;".format(1 if scope == Scope.private_scope() else 0)
            AND elib.language='Python' AND elib.scope={scope_num} \
            ORDER BY elib.name ASC;".format(scope_num=scope_num)
        return self._pyexecutor.execute_sql_query(query, owner)

    def _drop_sql_package(self, sql_package_name: str, scope: Scope, out_file: str):
    def _drop_sql_package(self, sql_package_name: str, scope: Scope, out_file: str = None):
        builder = DropLibraryBuilder(sql_package_name=sql_package_name, scope=scope)
        execute_query(builder, self._connection_info, out_file)

@@ -170,28 +170,30 @@ class SQLPackageManager:
        # Resolve which package dependencies need to be installed or upgraded on server.
        required_installs = resolver.get_required_installs(target_package_requirements)
        dependencies_to_install = self._get_required_files_to_install(requirements_downloaded, required_installs)

        self._install_many(target_package_file, dependencies_to_install, scope, out_file=out_file)

    def _install_many(self, target_package_file: str, dependency_files, scope: Scope, out_file: str = None):
        target_name = get_package_name_from_file(target_package_file)

        with SQLQueryExecutor(connection=self._connection_info) as sqlexecutor:
            transaction = SQLTransaction(sqlexecutor, clean_library_name(target_name) + "InstallTransaction")
            transaction.begin()
            sqlexecutor._cnxn.autocommit = False
            try:
                print("Installing dependencies...")
                for pkgfile in dependency_files:
                    self._install_single(sqlexecutor, pkgfile, scope, out_file=out_file)

                print("Done with dependencies, installing main package...")
                self._install_single(sqlexecutor, target_package_file, scope, True, out_file=out_file)
                transaction.commit()
                sqlexecutor._cnxn.commit()
            except Exception as e:
                transaction.rollback()
                sqlexecutor._cnxn.rollback()
                raise RuntimeError("Package installation failed, installed dependencies were rolled back.") from e

    @staticmethod
    def _install_single(sqlexecutor: SQLQueryExecutor, package_file: str, scope: Scope, is_target=False, out_file: str = None):
        name = get_package_name_from_file(package_file)
        version = get_package_version_from_file(package_file)
        name = str(get_package_name_from_file(package_file))
        version = str(get_package_version_from_file(package_file))
        print("Installing {name} version: {version}".format(name=name, version=version))

        with tempfile.TemporaryDirectory() as temporary_directory:
            prezip = os.path.join(temporary_directory, name + "PREZIP.zip")

@@ -199,7 +201,9 @@ class SQLPackageManager:
            zipf.write(package_file, os.path.basename(package_file))

            builder = CreateLibraryBuilder(pkg_name=name, pkg_filename=prezip, scope=scope)
            sqlexecutor.execute(builder, out_file=out_file, getResults=False)
            sqlexecutor.execute(builder, out_file=out_file)
            builder = CheckLibraryBuilder(pkg_name=name, scope=scope)
            sqlexecutor.execute(builder, out_file=out_file)

    @staticmethod
    def _get_required_files_to_install(pkgfiles, requirements):
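
As the _install_many hunk above shows, the named SQLTransaction wrapper is replaced with pyodbc connection-level transaction control: autocommit is switched off, the work is committed on success, and rolled back on failure. A minimal standalone sketch of that pattern; the connection string and statements are hypothetical:

```python
import pyodbc

cnxn = pyodbc.connect("Driver={ODBC Driver 13 for SQL Server};"
                      "Server=localhost;Database=master;Trusted_Connection=Yes;")

# Turn autocommit off so the statements below form one transaction
cnxn.autocommit = False
cursor = cnxn.cursor()
try:
    cursor.execute("INSERT INTO dbo.Demo (val) VALUES (?)", 1)  # hypothetical table
    cursor.execute("INSERT INTO dbo.Demo (val) VALUES (?)", 2)
    cnxn.commit()    # both rows persist together
except Exception:
    cnxn.rollback()  # neither row persists if anything failed
    raise
```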

@@ -1,16 +1,13 @@
# Copyright(c) Microsoft Corporation. All rights reserved.
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Callable, List
import abc
import dill
import inspect
import textwrap
from pandas import DataFrame
import warnings

RETURN_COLUMN_NAME = "return_val"

from pandas import DataFrame
from typing import Callable, List

"""
_SQLBuilder implementations are used to generate SQL scripts to execute_function_in_sql Python functions and

@@ -25,6 +22,9 @@ All _SQLBuilder classes implement a base_script property. This is the text of th
return values in their params property.
"""

RETURN_COLUMN_NAME = "return_val"
STDOUT_COLUMN_NAME = "_stdout_"
STDERR_COLUMN_NAME = "_stderr_"

class SQLBuilder:

@@ -43,9 +43,12 @@ class SpeesBuilder(SQLBuilder):

    """

    _WITH_RESULTS_TEXT = "with result sets(({stdout} varchar(MAX), {stderr} varchar(MAX)))".format(
        stdout=STDOUT_COLUMN_NAME, stderr=STDERR_COLUMN_NAME)

    def __init__(self,
                 script: str,
                 with_results_text: str = "",
                 with_results_text: str = _WITH_RESULTS_TEXT,
                 input_data_query: str = "",
                 script_parameters_text: str = ""):
        """Instantiate a _SpeesBuilder object.

@@ -55,7 +58,7 @@ class SpeesBuilder(SQLBuilder):
        :param input_data_query: maps to @input_data_1 SQL query parameter
        :param script_parameters_text: maps to @params SQL query parameter
        """
        self._script = script
        self._script = self.modify_script(script)
        self._input_data_query = input_data_query
        self._script_parameters_text = script_parameters_text
        self._with_results_text = with_results_text

@@ -65,17 +68,37 @@ class SpeesBuilder(SQLBuilder):
        return """
exec sp_execute_external_script
@language = N'Python',
@script = %s,
@input_data_1 = %s
{script_parameters}
@script = ?,
@input_data_1 = ?
{script_parameters_text}
{with_results_text}
""".format(script_parameters=self._script_parameters_text,
           with_results_text=self._with_results_text)
""".format(script_parameters_text=self._script_parameters_text,
           with_results_text=self._with_results_text)

    @property
    def params(self):
        return self._script, self._input_data_query

    def modify_script(self, script):
        return """
import sys
from io import StringIO
from pandas import DataFrame

_temp_out = StringIO()
_temp_err = StringIO()

sys.stdout = _temp_out
sys.stderr = _temp_err
OutputDataSet = DataFrame()

{script}

OutputDataSet["{stdout}"] = [_temp_out.getvalue()]
OutputDataSet["{stderr}"] = [_temp_err.getvalue()]
""".format(script=script,
           stdout=STDOUT_COLUMN_NAME,
           stderr=STDERR_COLUMN_NAME)

class SpeesBuilderFromFunction(SpeesBuilder):

@@ -83,7 +106,11 @@ class SpeesBuilderFromFunction(SpeesBuilder):
    _SpeesBuilderFromFunction objects are used to generate SPEES queries based on a function and given arguments.
    """

    _WITH_RESULTS_TEXT = "with result sets((return_val varchar(MAX)))"
    _WITH_RESULTS_TEXT = "with result sets(({returncol} varchar(MAX), {stdout} varchar(MAX), {stderr} varchar(MAX)))".format(
        returncol=RETURN_COLUMN_NAME,
        stdout=STDOUT_COLUMN_NAME,
        stderr=STDERR_COLUMN_NAME
    )

    def __init__(self, func: Callable, input_data_query: str = "", *args, **kwargs):
        """Instantiate a _SpeesBuilderFromFunction object.

@@ -112,11 +139,12 @@ class SpeesBuilderFromFunction(SpeesBuilder):
        args_dill = dill.dumps(kwargs).hex()
        pos_args_dill = dill.dumps(args).hex()
        function_name = func.__name__
        func_arguments = SpeesBuilderFromFunction._func_arguments(with_inputdf)

        return """
{user_function_text}
{function_text}

import dill
import pandas as pd

# serialized keyword arguments
args_dill = bytes.fromhex("{args_dill}")

@@ -127,20 +155,21 @@ args = dill.loads(args_dill)
pos_args = dill.loads(pos_args_dill)

# user function name
func = {user_function_name}
func = {function_name}

# call user function with serialized arguments
return_val = func{func_arguments}
{returncol} = func{func_arguments}

return_frame = pd.DataFrame()
# serialize results of user function and put in DataFrame for return through SQL Satellite channel
return_frame["return_val"] = [dill.dumps(return_val).hex()]
OutputDataSet = return_frame
""".format(user_function_text=function_text,
           args_dill=args_dill,
           pos_args_dill=pos_args_dill,
           user_function_name=function_name,
           func_arguments=SpeesBuilderFromFunction._func_arguments(with_inputdf))
OutputDataSet["{returncol}"] = [dill.dumps({returncol}).hex()]
""".format(
    function_text=function_text,
    args_dill=args_dill,
    pos_args_dill=pos_args_dill,
    function_name=function_name,
    returncol=RETURN_COLUMN_NAME,
    func_arguments=func_arguments
)

# Call syntax of the user function
# When with_inputdf is true, the user function will always take the "InputDataSet" magic variable as its first

@@ -169,6 +198,10 @@ class StoredProcedureBuilder(SQLBuilder):
            input_params = {}
        if output_params is None:
            output_params = {}

        output_params[STDOUT_COLUMN_NAME] = str
        output_params[STDERR_COLUMN_NAME] = str

        self._script = script
        self._name = name
        self._input_params = input_params

@@ -191,19 +224,30 @@ class StoredProcedureBuilder(SQLBuilder):

        return """
CREATE PROCEDURE {name}
{parameter_declarations}
{param_declarations}
AS
SET NOCOUNT ON;
EXEC sp_execute_external_script
@language = N'Python',
@script = %s
{script_parameters}
""".format(name=self._name,
           parameter_declarations=self._param_declarations,
           script_parameters=self._script_parameter_text)

    @property
    def params(self):
        return self._script
@script = N'
from io import StringIO
import sys
_stdout = StringIO()
_stderr = StringIO()
sys.stdout = _stdout
sys.stderr = _stderr
{script}
{stdout} = _stdout.getvalue()
{stderr} = _stderr.getvalue()'
{script_parameter_text}
""".format(
    name=self._name,
    param_declarations=self._param_declarations,
    script=self._script,
    stdout=STDOUT_COLUMN_NAME,
    stderr=STDERR_COLUMN_NAME,
    script_parameter_text=self._script_parameter_text
)

    def script_parameter_text(self, in_names: List[str], in_types: dict, out_names: List[str], out_types: dict) -> str:
        if not in_names and not out_names:
@@ -228,16 +272,17 @@ EXEC sp_execute_external_script
            break

        if in_data_name != "":
            script_params += ",\n" + StoredProcedureBuilderFromFunction.get_input_data_set(in_data_name)
            script_params += ",\n" + self.get_input_data_set(in_data_name)

        if out_data_name != "":
            script_params += ",\n" + StoredProcedureBuilderFromFunction.get_output_data_set(out_data_name)
            script_params += ",\n" + self.get_output_data_set(out_data_name)

        if len(in_names) > 0:
        if len(in_names) > 0 or len(out_names) > 0:
            script_params += ","

        in_params_declaration = out_params_declaration = ""
        in_params_passing = out_params_passing = ""

        if len(in_names) > 0:
            in_params_declaration = self.get_declarations(in_names, in_types)
            in_params_passing = self.get_params_passing(in_names)

@@ -250,9 +295,10 @@ EXEC sp_execute_external_script
        params_passing = self.combine_in_out(in_params_passing, out_params_passing)

        if params_declaration != "":
            script_params += "\n@params = N'{params_declarations}',\n {params_passing}".format(
                params_declarations=params_declaration,
                params_passing=params_passing)
            script_params += "\n@params = N'{params_declaration}',\n {params_passing}".format(
                params_declaration=params_declaration,
                params_passing=params_passing
            )

        return script_params

@@ -274,10 +320,11 @@ EXEC sp_execute_external_script

    @staticmethod
    def get_declarations(names_of_args: List[str], type_annotations: dict, outputs: bool = False):
        return ",\n ".join(["@" + name + " {sqltype}{output}".format(
            sqltype=StoredProcedureBuilder.to_sql_type(type_annotations.get(name, None)),
            output=" OUTPUT" if outputs else ""
        ) for name in names_of_args])
        return ",\n ".join(["@{name} {sqltype}{output}".format(
            name=name,
            sqltype=StoredProcedureBuilder.to_sql_type(type_annotations.get(name, None)),
            output=" OUTPUT" if outputs else "")
            for name in names_of_args])

    @staticmethod
    def to_sql_type(pytype):

@@ -294,36 +341,38 @@ EXEC sp_execute_external_script

    @staticmethod
    def get_params_passing(names_of_args, outputs: bool = False):
        return ",\n ".join(["@" + name + " = " + "@" + name + "{output}".format(output=" OUTPUT" if outputs else "")
                            for name in names_of_args])
        return ",\n ".join(["@{name} = @{name} {output}".format(
            name=name,
            output=" OUTPUT" if outputs else "")
            for name in names_of_args])


class StoredProcedureBuilderFromFunction(StoredProcedureBuilder):

    """Build query text for stored procedures creation based on Python functions.

    ex:
    ex:

    name: "MyStoredProcedure"
    func:
    def foobar(arg1: str, arg2: str, arg3: str):
        print(arg1, arg2, arg3)
    name: "MyStoredProcedure"
    func:
    def foobar(arg1: str, arg2: str, arg3: str):
        print(arg1, arg2, arg3)

    ===========becomes===================
    ===========becomes===================

    create procedure MyStoredProcedure @arg1 varchar(MAX), @arg2 varchar(MAX), @arg3 varchar(MAX) as
    create procedure MyStoredProcedure @arg1 varchar(MAX), @arg2 varchar(MAX), @arg3 varchar(MAX) as

    exec sp_execute_external_script
    @language = N'Python',
    @script=N'
    def foobar(arg1, arg2, arg3):
        print(arg1, arg2, arg3)
    foobar(arg1=arg1, arg2=arg2, arg3=arg3)
    ',
    @params = N'@arg1 varchar(MAX), @arg2 varchar(MAX), @arg3 varchar(MAX)',
    @arg1 = @arg1,
    @arg2 = @arg2,
    @arg3 = @arg3
    exec sp_execute_external_script
    @language = N'Python',
    @script=N'
    def foobar(arg1, arg2, arg3):
        print(arg1, arg2, arg3)
    foobar(arg1=arg1, arg2=arg2, arg3=arg3)
    ',
    @params = N'@arg1 varchar(MAX), @arg2 varchar(MAX), @arg3 varchar(MAX)',
    @arg1 = @arg1,
    @arg2 = @arg2,
    @arg3 = @arg3
    """

    def __init__(self, name: str, func: Callable,

@@ -340,13 +389,18 @@ foobar(arg1=arg1, arg2=arg2, arg3=arg3)
            input_params = {}
        if output_params is None:
            output_params = {}

        output_params[STDOUT_COLUMN_NAME] = str
        output_params[STDERR_COLUMN_NAME] = str

        self._func = func
        self._name = name
        self._output_params = output_params

        # Get function information
        function_text = textwrap.dedent(inspect.getsource(self._func))
        # Get function text and escape single quotes
        function_text = textwrap.dedent(inspect.getsource(self._func)).replace("'", "''")

        # Get function arguments and type annotations
        argspec = inspect.getfullargspec(self._func)
        names_of_input_args = argspec.args
        annotations = argspec.annotations

@@ -365,12 +419,12 @@ foobar(arg1=arg1, arg2=arg2, arg3=arg3)
            self._input_params = {}

        names_of_output_args = list(self._output_params)

        if len(names_of_input_args) != len(self._input_params):
            raise ValueError("Number of argument annotations doesn't match the number of arguments!")
        if set(names_of_input_args) != set(self._input_params.keys()):
            raise ValueError("Names of arguments do not match the annotation keys!")

        calling_text = self.get_function_calling_text(self._func, names_of_input_args)

        output_data_set = None

@@ -380,14 +434,18 @@ foobar(arg1=arg1, arg2=arg2, arg3=arg3)
            output_data_set = name
            break

        ending = self.get_ending(self._output_params, output_data_set)
        # Creates the base python script to put in the SPEES query.
        # Arguments to function are passed by name into script using SPEES @params argument.
        self._script = """
{function_text}
{function_call_text}
{calling_text}
{ending}
""".format(function_text=function_text, function_call_text=calling_text,
           ending=self.get_ending(self._output_params, output_data_set))
""".format(
    function_text=function_text,
    calling_text=calling_text,
    ending=ending
)

        self._in_parameter_declarations = self.get_declarations(names_of_input_args, self._input_params)
        self._out_parameter_declarations = self.get_declarations(names_of_output_args, self._output_params,
@ -404,55 +462,113 @@ foobar(arg1=arg1, arg2=arg2, arg3=arg3)
|
|||
def get_function_calling_text(func: Callable, names_of_args: List[str]):
|
||||
# For a function named foo with signature def foo(arg1, arg2, arg3)...
|
||||
# kwargs_text is 'arg1=arg1, arg2=arg2, arg3=arg3'
|
||||
kwargs_text = ", ".join("{}={}".format(name, name) for name in names_of_args)
|
||||
kwargs_text = ", ".join("{name}={name}".format(name=name) for name in names_of_args)
|
||||
|
||||
# returns 'foo(arg1=arg2, arg2=arg2, arg3=arg3)'
|
||||
return "result = " + func.__name__ + "({})".format(kwargs_text)
|
||||
return "result = {name}({kwargs})".format(name=func.__name__, kwargs=kwargs_text)
|
||||
|
||||
# Convert results to Output data frame and Output parameters
|
||||
def get_ending(self, output_params: dict, output_data_set: str):
|
||||
def get_ending(self, output_params: dict, output_data_set_name: str):
|
||||
out_df = output_data_set_name if output_data_set_name is not None else "OutputDataSet"
|
||||
res = """
|
||||
if type(result) == DataFrame:
|
||||
{result_val}""".format(result_val="{out_df} = result".format(out_df=output_data_set
|
||||
if output_data_set is not None else "OutputDataSet"))
|
||||
{out_df} = result
|
||||
""".format(out_df = out_df)
|
||||
|
||||
if len(output_params) > 0 or output_data_set is not None:
|
||||
trimmed_output_params = output_params.copy()
|
||||
trimmed_output_params.pop(STDOUT_COLUMN_NAME, None)
|
||||
trimmed_output_params.pop(STDERR_COLUMN_NAME, None)
|
||||
|
||||
if len(trimmed_output_params) > 0 or output_data_set_name is not None:
|
||||
output_params = self.get_output_params(trimmed_output_params) if len(trimmed_output_params) > 0 else "pass"
|
||||
res += """
|
||||
elif type(result) == dict:
|
||||
{output_params}
|
||||
elif result is not None:
|
||||
raise TypeError("Must return a DataFrame or dictionary with output parameters or None")
|
||||
""".format(output_params=self.get_output_params(output_params) if len(output_params) > 0 else "pass")
|
||||
""".format(output_params = output_params)
|
||||
return res
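For reference, the contract this ending enforces on wrapped user functions: return a DataFrame to populate OutputDataSet, or a dict keyed by the declared output parameter names. A minimal illustrative function (the names mirror the tests later in this commit):

```python
from pandas import DataFrame

# Returning a DataFrame fills OutputDataSet; returning a dict fills the
# declared output parameters by name ("out_df" and "param_str" are example names).
def in_out(t1: str, t3: DataFrame):
    param_str = "Hello " + t1
    return {"out_df": t3, "param_str": param_str}
```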
|
||||
|
||||
@staticmethod
|
||||
def get_output_params(output_params: dict):
|
||||
return "\n ".join(['{name} = result["{name}"]'.format(name=name) for name in list(output_params)])
|
||||
return "\n ".join(['{name} = result["{name}"]'.format(name=name)
|
||||
for name in list(output_params)])
|
||||
|
||||
|
||||
class ExecuteStoredProcedureBuilder(SQLBuilder):
|
||||
|
||||
def __init__(self, name: str, **kwargs):
|
||||
def __init__(self, name: str, output_params: dict = None, **kwargs):
|
||||
self._name = name
|
||||
self._kwargs = kwargs
|
||||
self._output_params = output_params
|
||||
|
||||
# Execute the query: exec sproc @var1 = val1, @var2 = val2...
|
||||
# Does not work with output parameters
|
||||
@property
|
||||
def base_script(self) -> str:
|
||||
parameters = ", ".join(["@{name} = {value}".format(name=name, value=self.format_value(self._kwargs[name]))
|
||||
for name in self._kwargs])
|
||||
return """exec {} {}""".format(self._name, parameters)
|
||||
if self._output_params is not None:
|
||||
# Remove DataFrame from the output parameters, the DataFrame will be the OutputDataSet
|
||||
for name, py_type in list(self._output_params.items()):
|
||||
if py_type == DataFrame:
|
||||
del self._output_params[name]
|
||||
|
||||
parameters = " ".join(["@{name} = {value},".format(name=name, value=self.format_value(self._kwargs[name]))
|
||||
for name in self._kwargs])
|
||||
|
||||
retval = """
|
||||
DECLARE @{stdout} nvarchar(MAX),
|
||||
@{stderr} nvarchar(MAX)
|
||||
{output_declarations}
|
||||
|
||||
exec {sproc_name} {parameters}
|
||||
@{stdout} = @{stdout} OUTPUT,
|
||||
@{stderr} = @{stderr} OUTPUT
|
||||
{output_calls}
|
||||
|
||||
SELECT @{stdout} as {stdout},
|
||||
@{stderr} as {stderr}
|
||||
{output_selects}
|
||||
""".format(stdout=STDOUT_COLUMN_NAME,
|
||||
stderr=STDERR_COLUMN_NAME,
|
||||
output_declarations=self.output_declarations(self._output_params),
|
||||
sproc_name=self._name,
|
||||
parameters=parameters,
|
||||
output_calls=self.output_calls(self._output_params),
|
||||
output_selects=self.output_selects(self._output_params))
|
||||
return retval
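To see what this builder produces, a small sketch of driving it directly (assuming it is importable from sqlmlutils.sqlbuilder as in the package's own imports; the sproc name and parameters are invented):

```python
from sqlmlutils.sqlbuilder import ExecuteStoredProcedureBuilder

# Builds the DECLARE / exec ... OUTPUT / SELECT batch shown above for a
# hypothetical sproc with one input argument and one string output parameter.
builder = ExecuteStoredProcedureBuilder("my_sproc", output_params={"res": str}, t1="Hello")
print(builder.base_script)
```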
|
||||
|
||||
@staticmethod
|
||||
def format_value(value) -> str:
|
||||
if isinstance(value, str):
|
||||
return "'{}'".format(value)
|
||||
return "'{value}'".format(value=value)
|
||||
elif isinstance(value, int) or isinstance(value, float):
|
||||
return str(value)
|
||||
elif isinstance(value, bool):
|
||||
return str(int(value))
|
||||
else:
|
||||
raise ValueError("Parameter type {} not supported.".format(str(type(value))))
|
||||
raise ValueError("Parameter type {value_type} not supported.".format(value_type = str(type(value))))
|
||||
|
||||
def output_declarations(self, output_params):
|
||||
retval = ""
|
||||
if output_params is not None and len(output_params) > 0:
|
||||
retval += "".join([", @{name} {type}".format(name=name,
|
||||
type=StoredProcedureBuilderFromFunction.to_sql_type(output_params[name]))
|
||||
for name in output_params])
|
||||
return retval
|
||||
|
||||
def output_calls(self, output_params):
|
||||
retval = ""
|
||||
if output_params is not None and len(output_params) > 0:
|
||||
retval += "".join([", @{name} = @{name} OUTPUT".format(name=name)
|
||||
for name in output_params])
|
||||
return retval
|
||||
|
||||
def output_selects(self, output_params):
|
||||
retval = ""
|
||||
if output_params is not None and len(output_params) > 0:
|
||||
retval += "".join([", @{name} as {name}".format(name=name)
|
||||
for name in output_params])
|
||||
return retval
|
||||
|
||||
|
||||
class DropStoredProcedureBuilder(SQLBuilder):
|
||||
|
@ -462,6 +578,4 @@ class DropStoredProcedureBuilder(SQLBuilder):
|
|||
|
||||
@property
|
||||
def base_script(self) -> str:
|
||||
return """
|
||||
drop procedure {}
|
||||
""".format(self._name)
|
||||
return "drop procedure {name}".format(name=self._name)
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
# Copyright(c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright(c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Callable
|
||||
import dill
|
||||
import sys
|
||||
|
||||
from typing import Callable
|
||||
from pandas import DataFrame
|
||||
|
||||
from .connectioninfo import ConnectionInfo
|
||||
from .sqlqueryexecutor import execute_query, execute_raw_query
|
||||
from .sqlbuilder import SpeesBuilder, SpeesBuilderFromFunction, StoredProcedureBuilder, \
|
||||
ExecuteStoredProcedureBuilder, DropStoredProcedureBuilder
|
||||
from .sqlbuilder import StoredProcedureBuilderFromFunction, RETURN_COLUMN_NAME
|
||||
from .sqlbuilder import StoredProcedureBuilderFromFunction
|
||||
from .sqlbuilder import RETURN_COLUMN_NAME, STDOUT_COLUMN_NAME, STDERR_COLUMN_NAME
|
||||
|
||||
|
||||
class SQLPythonExecutor:
|
||||
|
@ -44,8 +47,13 @@ class SQLPythonExecutor:
|
|||
>>> print(ret)
|
||||
[0.28366218546322625, 0.28366218546322625]
|
||||
"""
|
||||
rows = execute_query(SpeesBuilderFromFunction(func, input_data_query, *args, **kwargs), self._connection_info)
|
||||
return self._get_results(rows)
|
||||
df, _ = execute_query(SpeesBuilderFromFunction(func, input_data_query, *args, **kwargs), self._connection_info)
|
||||
results, output, error = self._get_results(df)
|
||||
if output is not None:
|
||||
print(output)
|
||||
if error is not None:
|
||||
print(error, file=sys.stderr)
|
||||
return results
|
||||
|
||||
def execute_script_in_sql(self,
|
||||
path_to_script: str,
|
||||
|
@ -61,7 +69,6 @@ class SQLPythonExecutor:
|
|||
try:
|
||||
with open(path_to_script, 'r') as script_file:
|
||||
content = script_file.read()
|
||||
print("File does exist, using " + path_to_script)
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError("File does not exist!")
|
||||
execute_query(SpeesBuilder(content, input_data_query=input_data_query), connection=self._connection_info)
|
||||
|
@ -74,16 +81,7 @@ class SQLPythonExecutor:
|
|||
:param sql_query: the sql query to execute in the server
|
||||
:return: table returned by the sql_query
|
||||
"""
|
||||
rows = execute_raw_query(conn=self._connection_info, query=sql_query, params=params)
|
||||
df = DataFrame(rows)
|
||||
|
||||
# _mssql's execute_query() returns duplicate keys for indexing, we remove them because they are extraneous
|
||||
for i in range(len(df.columns)):
|
||||
try:
|
||||
del df[i]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
df, _ = execute_raw_query(conn=self._connection_info, query=sql_query, params=params)
|
||||
return df
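With pyodbc underneath, raw queries use '?' placeholders rather than pymssql's '%s'. A hedged usage sketch (sqlpy is a SQLPythonExecutor; the table and columns follow the AirlineTestDB used by this repo's tests, and the tuple form of params is an assumption):

```python
# Parameter values are bound by pyodbc rather than interpolated into the query string.
df = sqlpy.execute_sql_query(
    "SELECT TOP 5 ArrTime, CRSDepTime FROM airline5000 WHERE ArrTime > ?",
    params=(1200,))
print(df.shape)
```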
|
||||
|
||||
def create_sproc_from_function(self, name: str, func: Callable,
|
||||
|
@ -124,9 +122,16 @@ class SQLPythonExecutor:
|
|||
input_params = {}
|
||||
if output_params is None:
|
||||
output_params = {}
|
||||
|
||||
# We modify input_params/output_params because we add stdout and stderr as params.
|
||||
# We copy here to avoid modifying the underlying contents.
|
||||
#
|
||||
in_copy = input_params.copy() if input_params is not None else None
|
||||
out_copy = output_params.copy() if output_params is not None else None
|
||||
|
||||
# Save the stored procedure in database
|
||||
execute_query(StoredProcedureBuilderFromFunction(name, func,
|
||||
input_params, output_params), self._connection_info)
|
||||
execute_query(StoredProcedureBuilderFromFunction(name, func, in_copy, out_copy),
|
||||
self._connection_info)
|
||||
return True
|
||||
|
||||
def create_sproc_from_script(self, name: str, path_to_script: str,
|
||||
|
@ -162,8 +167,14 @@ class SQLPythonExecutor:
|
|||
except FileNotFoundError:
|
||||
raise FileNotFoundError("File does not exist!")
|
||||
|
||||
execute_query(StoredProcedureBuilder(name, content,
|
||||
input_params, output_params), self._connection_info)
|
||||
# We modify input_params/output_params because we add stdout and stderr as params.
|
||||
# We copy here to avoid modifying the underlying contents.
|
||||
#
|
||||
in_copy = input_params.copy() if input_params is not None else None
|
||||
out_copy = output_params.copy() if output_params is not None else None
|
||||
|
||||
execute_query(StoredProcedureBuilder(name, content, in_copy, out_copy),
|
||||
self._connection_info)
|
||||
return True
|
||||
|
||||
def check_sproc(self, name: str) -> bool:
|
||||
|
@ -180,20 +191,28 @@ class SQLPythonExecutor:
|
|||
:param name: name of stored procedure.
|
||||
:return: boolean whether the Stored Procedure exists in the database
|
||||
"""
|
||||
check_query = "SELECT OBJECT_ID (%s, N'P')"
|
||||
rows = execute_raw_query(conn=self._connection_info, query=check_query, params=name)
|
||||
return rows[0][0] is not None
|
||||
check_query = "SELECT OBJECT_ID (?, N'P')"
|
||||
rows = execute_raw_query(conn=self._connection_info, query=check_query, params=name)[0]
|
||||
return rows.loc[0].iloc[0] is not None
|
||||
|
||||
def execute_sproc(self, name: str, **kwargs) -> DataFrame:
|
||||
def execute_sproc(self, name: str, output_params: dict = None, **kwargs) -> DataFrame:
|
||||
"""Call a stored procedure on a SQL Server database.
|
||||
WARNING: output parameters must be declared (with their Python types) when creating the stored procedure
|
||||
and passed via output_params when executing; a DataFrame-typed output parameter is returned as the OutputDataSet
|
||||
|
||||
:param name: name of stored procedure.
|
||||
:param name: name of stored procedure
|
||||
:param output_params: output parameters (if any) for the stored procedure
|
||||
:param kwargs: keyword arguments to pass to stored procedure
|
||||
:return: DataFrame representing the output data set of the stored procedure (or empty)
|
||||
:return: tuple with a DataFrame representing the output data set of the stored procedure
|
||||
and a dictionary of output parameters
|
||||
"""
|
||||
return DataFrame(execute_query(ExecuteStoredProcedureBuilder(name, **kwargs), self._connection_info))
|
||||
|
||||
# We modify output_params because we add stdout and stderr as output params.
|
||||
# We copy here to avoid modifying the underlying contents.
|
||||
#
|
||||
out_copy = output_params.copy() if output_params is not None else None
|
||||
return execute_query(ExecuteStoredProcedureBuilder(name, out_copy, **kwargs),
|
||||
self._connection_info)
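A hedged sketch of the new call shape (sqlpy is a SQLPythonExecutor; the sproc and parameter names are invented, but the pattern matches the tests below):

```python
# execute_sproc now returns a (DataFrame, output-parameter dict) tuple.
df, outparams = sqlpy.execute_sproc("my_sproc",
                                    output_params={"param_str": str},
                                    t1="Hello",
                                    t3="SELECT TOP 10 * FROM airline5000")
print(df.shape)
print(outparams["param_str"])
```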
|
||||
|
||||
def drop_sproc(self, name: str):
|
||||
"""Drop a SQL Server stored procedure if it exists.
|
||||
|
@ -205,6 +224,8 @@ class SQLPythonExecutor:
|
|||
execute_query(DropStoredProcedureBuilder(name), self._connection_info)
|
||||
|
||||
@staticmethod
|
||||
def _get_results(rows):
|
||||
hexstring = rows[0][RETURN_COLUMN_NAME]
|
||||
return dill.loads(bytes.fromhex(hexstring))
|
||||
def _get_results(df : DataFrame):
|
||||
hexstring = df[RETURN_COLUMN_NAME][0]
|
||||
stdout_string = df[STDOUT_COLUMN_NAME][0]
|
||||
stderr_string = df[STDERR_COLUMN_NAME][0]
|
||||
return dill.loads(bytes.fromhex(hexstring)), stdout_string, stderr_string
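The hex/dill round trip is easy to reproduce locally; a self-contained sketch of what _get_results undoes:

```python
import dill

# The SPEES script hex-encodes a dill pickle of the function's return value;
# the client decodes it the same way as above.
hexstring = dill.dumps([0.28, 0.28]).hex()
assert dill.loads(bytes.fromhex(hexstring)) == [0.28, 0.28]
```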
|
||||
|
|
|
@ -1,11 +1,16 @@
|
|||
# Copyright(c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright(c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import _mssql
|
||||
import pyodbc
|
||||
import sys
|
||||
|
||||
from pandas import DataFrame
|
||||
|
||||
from .connectioninfo import ConnectionInfo
|
||||
from .sqlbuilder import SQLBuilder
|
||||
from .sqlbuilder import STDOUT_COLUMN_NAME, STDERR_COLUMN_NAME
|
||||
|
||||
"""This module is used to actually execute sql queries. It uses the pymssql module under the hood.
|
||||
"""This module is used to actually execute sql queries. It uses the pyodbc module under the hood.
|
||||
|
||||
It is mostly setup to work with SQLBuilder objects as defined in sqlbuilder.
|
||||
"""
|
||||
|
@ -23,13 +28,7 @@ def execute_raw_query(conn: ConnectionInfo, query, params=()):
|
|||
with SQLQueryExecutor(connection=conn) as executor:
|
||||
return executor.execute_query(query, params)
|
||||
|
||||
|
||||
def _sql_msg_handler(msgstate, severity, srvname, procname, line, msgtext):
|
||||
print(msgtext.decode())
|
||||
|
||||
|
||||
class SQLQueryExecutor:
|
||||
|
||||
"""_SQLQueryExecutor objects keep a SQL connection open in order to execute_function_in_sql one or more queries.
|
||||
|
||||
This class implements the basic context manager paradigm.
|
||||
|
@ -38,83 +37,87 @@ class SQLQueryExecutor:
|
|||
def __init__(self, connection: ConnectionInfo):
|
||||
self._connection = connection
|
||||
|
||||
def execute(self, builder: SQLBuilder, out_file=None, getResults=True):
|
||||
def execute(self, builder: SQLBuilder, out_file=None):
|
||||
return self.execute_query(builder.base_script, builder.params, out_file=out_file)
|
||||
|
||||
def execute_query(self, query, params, out_file=None):
|
||||
df = DataFrame()
|
||||
output_params = None
|
||||
|
||||
try:
|
||||
if out_file is not None:
|
||||
with open(out_file,"a") as f:
|
||||
if builder.params is not None:
|
||||
script = builder.base_script.replace("%s", "N'%s'")
|
||||
f.write(script % builder.params)
|
||||
if params is not None:
|
||||
script = query.replace("?", "N'%s'")
|
||||
|
||||
# Convert bytearray to hex so user can run as a script
|
||||
#
|
||||
if type(params) is bytearray:
|
||||
params = str('0x' + params.hex())
|
||||
|
||||
f.write(script % params)
|
||||
else:
|
||||
f.write(builder.base_script)
|
||||
f.write(query)
|
||||
f.write("GO\n")
|
||||
f.write("-----------------------------")
|
||||
else:
|
||||
self._mssqlconn.set_msghandler(_sql_msg_handler)
|
||||
if getResults:
|
||||
self._mssqlconn.execute_query(builder.base_script, builder.params)
|
||||
return [row for row in self._mssqlconn]
|
||||
else:
|
||||
self._mssqlconn.execute_non_query(builder.base_script, builder.params)
|
||||
return []
|
||||
except Exception as e:
|
||||
raise RuntimeError("Error in SQL Execution") from e
|
||||
|
||||
def execute_query(self, query, params, out_file=None):
|
||||
if out_file is not None:
|
||||
with open(out_file, "a") as f:
|
||||
else:
|
||||
if params is not None:
|
||||
script = query.replace("%s", "'%s'")
|
||||
f.write(script % params)
|
||||
self._cursor.execute(query, params)
|
||||
else:
|
||||
f.write(query)
|
||||
f.write("GO\n")
|
||||
f.write("-----------------------------")
|
||||
self._mssqlconn.execute_query(query, params)
|
||||
return [row for row in self._mssqlconn]
|
||||
self._cursor.execute(query)
|
||||
|
||||
# Get the first resultset (OutputDataSet)
|
||||
#
|
||||
if self._cursor.description is not None:
|
||||
column_names = [element[0] for element in self._cursor.description]
|
||||
rows = [tuple(t) for t in self._cursor.fetchall()]
|
||||
df = DataFrame(rows, columns=column_names)
|
||||
if STDOUT_COLUMN_NAME in column_names:
|
||||
self.extract_output(dict(zip(column_names, rows[0])))
|
||||
|
||||
# Get output parameters
|
||||
#
|
||||
while self._cursor.nextset():
|
||||
try:
|
||||
if self._cursor.description is not None:
|
||||
column_names = [element[0] for element in self._cursor.description]
|
||||
rows = [tuple(t) for t in self._cursor.fetchall()]
|
||||
output_params = dict(zip(column_names, rows[0]))
|
||||
|
||||
if STDOUT_COLUMN_NAME in column_names:
|
||||
self.extract_output(output_params)
|
||||
|
||||
except pyodbc.ProgrammingError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError("Error in SQL Execution: " + str(e))
|
||||
|
||||
return df, output_params
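For readers new to pyodbc, the same multiple-result-set pattern in isolation (the driver name, connection string, and batch are placeholders):

```python
import pyodbc
from pandas import DataFrame

cnxn = pyodbc.connect("DRIVER={ODBC Driver 17 for SQL Server};SERVER=localhost;"
                      "DATABASE=master;Trusted_Connection=yes", autocommit=True)
cursor = cnxn.cursor()
cursor.execute("SELECT 1 AS a; SELECT 2 AS b")  # a batch that yields two result sets
frames = []
while True:
    if cursor.description is not None:          # skip statements that return no rows
        columns = [c[0] for c in cursor.description]
        frames.append(DataFrame([tuple(r) for r in cursor.fetchall()], columns=columns))
    if not cursor.nextset():                     # advance until no result sets remain
        break
cnxn.close()
```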
|
||||
|
||||
def __enter__(self):
|
||||
if self._connection.port == "":
|
||||
self._mssqlconn = _mssql.connect(server=self._connection.server,
|
||||
user=self._connection.uid,
|
||||
password=self._connection.pwd,
|
||||
database=self._connection.database)
|
||||
else:
|
||||
self._mssqlconn = _mssql.connect(server=self._connection.server,
|
||||
port=self._connection.port,
|
||||
user=self._connection.uid,
|
||||
password=self._connection.pwd,
|
||||
database=self._connection.database)
|
||||
self._mssqlconn.set_msghandler(_sql_msg_handler)
|
||||
server=self._connection._server if self._connection._port == "" \
|
||||
else "{server},{port}".format(
|
||||
server=self._connection._server,
|
||||
port=self._connection._port
|
||||
)
|
||||
|
||||
self._cnxn = pyodbc.connect(driver=self._connection._driver,
|
||||
server=server,
|
||||
user=self._connection.uid,
|
||||
password=self._connection.pwd,
|
||||
database=self._connection.database,
|
||||
autocommit=True)
|
||||
self._cursor = self._cnxn.cursor()
|
||||
return self
|
||||
|
||||
def __exit__(self, exception_type, exception_value, traceback):
|
||||
self._mssqlconn.close()
|
||||
|
||||
|
||||
class SQLTransaction:
|
||||
|
||||
def __init__(self, executor: SQLQueryExecutor, name):
|
||||
self._executor = executor
|
||||
self._name = name
|
||||
|
||||
def begin(self):
|
||||
query = """
|
||||
declare @transactionname varchar(MAX) = %s;
|
||||
begin tran @transactionname;
|
||||
"""
|
||||
self._executor.execute_query(query, self._name)
|
||||
|
||||
def rollback(self):
|
||||
query = """
|
||||
declare @transactionname varchar(MAX) = %s;
|
||||
rollback tran @transactionname;
|
||||
"""
|
||||
self._executor.execute_query(query, self._name)
|
||||
|
||||
def commit(self):
|
||||
query = """
|
||||
declare @transactionname varchar(MAX) = %s;
|
||||
commit tran @transactionname;
|
||||
"""
|
||||
self._executor.execute_query(query, self._name)
|
||||
self._cnxn.close()
|
||||
|
||||
def extract_output(self, output_params : dict):
|
||||
out = output_params.pop(STDOUT_COLUMN_NAME, None)
|
||||
err = output_params.pop(STDERR_COLUMN_NAME, None)
|
||||
if out is not None:
|
||||
print(out)
|
||||
if err is not None:
|
||||
print(err, file=sys.stderr)
|
|
@ -1,36 +0,0 @@
|
|||
# Copyright(c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from pandas import DataFrame
|
||||
|
||||
from .connectioninfo import ConnectionInfo
|
||||
from .sqlqueryexecutor import execute_query
|
||||
from .sqlbuilder import ExecuteStoredProcedureBuilder, DropStoredProcedureBuilder
|
||||
|
||||
|
||||
class StoredProcedure:
|
||||
"""Represents a SQL Server stored procedure."""
|
||||
|
||||
def __init__(self, name: str, connection: ConnectionInfo):
|
||||
"""Instantiates a StoredProcedure. Not meant to be called directly, get handles to stored
|
||||
procedures using get_sproc.
|
||||
|
||||
:param name: name of stored procedure.
|
||||
"""
|
||||
self._name = name
|
||||
self._connection = connection
|
||||
|
||||
def call(self, **kwargs) -> DataFrame:
|
||||
"""Call a stored procedure on a SQL Server database.
|
||||
|
||||
:param kwargs: keyword arguments to pass to stored procedure
|
||||
:return: DataFrame representing the output data set of the stored procedure (or empty)
|
||||
"""
|
||||
return DataFrame(execute_query(ExecuteStoredProcedureBuilder(self._name, **kwargs), self._connection))
|
||||
|
||||
def drop(self):
|
||||
"""Drop a SQL Server stored procedure.
|
||||
|
||||
:return: None
|
||||
"""
|
||||
execute_query(DropStoredProcedureBuilder(self._name), self._connection)
|
|
@ -1,5 +1,9 @@
|
|||
# Copyright(c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
from sqlmlutils import ConnectionInfo
|
||||
|
||||
from sqlmlutils import ConnectionInfo, Scope
|
||||
|
||||
driver = os.environ['DRIVER'] if 'DRIVER' in os.environ else "SQL Server"
|
||||
server = os.environ['SERVER'] if 'SERVER' in os.environ else "localhost"
|
||||
|
@ -7,6 +11,8 @@ database = os.environ['DATABASE'] if 'DATABASE' in os.environ else "AirlineTestD
|
|||
uid = os.environ['USER'] if 'USER' in os.environ else ""
|
||||
pwd = os.environ['PASSWORD'] if 'PASSWORD' in os.environ else ""
|
||||
|
||||
scope = Scope.public_scope() if uid == "" else Scope.private_scope()
|
||||
|
||||
|
||||
connection = ConnectionInfo(driver=driver,
|
||||
server=server,
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
# Copyright(c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright(c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import pytest
|
||||
from contextlib import redirect_stdout, redirect_stderr
|
||||
import io
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from sqlmlutils import SQLPythonExecutor
|
||||
from sqlmlutils import ConnectionInfo
|
||||
from contextlib import redirect_stdout, redirect_stderr
|
||||
from pandas import DataFrame
|
||||
|
||||
from sqlmlutils import ConnectionInfo, SQLPythonExecutor
|
||||
from conftest import driver, server, database, uid, pwd
|
||||
|
||||
connection = ConnectionInfo(driver=driver,
|
||||
|
@ -20,10 +20,8 @@ connection = ConnectionInfo(driver=driver,
|
|||
current_dir = os.path.dirname(__file__)
|
||||
script_dir = os.path.join(current_dir, "scripts")
|
||||
|
||||
print(connection)
|
||||
sqlpy = SQLPythonExecutor(connection)
|
||||
|
||||
|
||||
def test_with_named_args():
|
||||
def func_with_args(arg1, arg2):
|
||||
print(arg1)
|
||||
|
@ -138,7 +136,7 @@ def test_with_variables():
|
|||
var_s = "World"
|
||||
sqlpy.execute_function_in_sql(func_with_variables, s=var_s)
|
||||
|
||||
assert "World" in output.getvalue()
|
||||
assert 'World' in output.getvalue()
|
||||
|
||||
|
||||
def test_execute_query():
|
||||
|
@ -149,7 +147,7 @@ def test_execute_query():
|
|||
|
||||
|
||||
def test_execute_script():
|
||||
path = os.path.join(script_dir, "test_script.py")
|
||||
path = os.path.join(script_dir, "exec_script.py")
|
||||
|
||||
output = io.StringIO()
|
||||
with redirect_stderr(output), redirect_stdout(output):
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright(c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright(c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from sqlmlutils.sqlqueryexecutor import execute_raw_query
|
||||
|
@ -6,8 +6,10 @@ from sqlmlutils.sqlqueryexecutor import execute_raw_query
|
|||
|
||||
def _get_sql_package_table(connection):
|
||||
query = "select * from sys.external_libraries"
|
||||
return execute_raw_query(connection, query)
|
||||
out_df, outparams = execute_raw_query(connection, query)
|
||||
return out_df
|
||||
|
||||
|
||||
def _get_package_names_list(connection):
|
||||
return {dic['name']: dic['scope'] for dic in _get_sql_package_table(connection)}
|
||||
df = _get_sql_package_table(connection)
|
||||
return {x: y for x, y in zip(df['name'], df['scope'])}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright(c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright(c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import io
|
||||
|
@ -9,10 +9,8 @@ from contextlib import redirect_stdout
|
|||
|
||||
import pytest
|
||||
|
||||
import sqlmlutils
|
||||
from sqlmlutils import SQLPackageManager, SQLPythonExecutor
|
||||
from sqlmlutils import ConnectionInfo, SQLPackageManager, SQLPythonExecutor, Scope
|
||||
from package_helper_functions import _get_sql_package_table, _get_package_names_list
|
||||
from sqlmlutils.packagemanagement.scope import Scope
|
||||
from sqlmlutils.packagemanagement.pipdownloader import PipDownloader
|
||||
|
||||
from conftest import connection
|
||||
|
@ -54,16 +52,20 @@ def _drop(package_name: str, ddl_name: str):
|
|||
|
||||
|
||||
def _create(module_name: str, package_file: str, class_to_check: str, drop: bool = True):
|
||||
pyexecutor.execute_function_in_sql(check_package, package_name=module_name, exists=False)
|
||||
pkgmanager.install(package_file)
|
||||
pyexecutor.execute_function_in_sql(check_package, package_name=module_name, exists=True, class_to_check=class_to_check)
|
||||
if drop:
|
||||
_drop(package_name=module_name, ddl_name=module_name)
|
||||
try:
|
||||
pyexecutor.execute_function_in_sql(check_package, package_name=module_name, exists=False)
|
||||
pkgmanager.install(package_file)
|
||||
pyexecutor.execute_function_in_sql(check_package, package_name=module_name, exists=True, class_to_check=class_to_check)
|
||||
finally:
|
||||
if drop:
|
||||
_drop(package_name=module_name, ddl_name=module_name)
|
||||
|
||||
|
||||
def _remove_all_new_packages(manager):
|
||||
libs = {dic['external_library_id']: (dic['name'], dic['scope']) for dic in _get_sql_package_table(connection)}
|
||||
original_libs = {dic['external_library_id']: (dic['name'], dic['scope']) for dic in originals}
|
||||
df = _get_sql_package_table(connection)
|
||||
|
||||
libs = {df['external_library_id'][i]: (df['name'][i], df['scope'][i]) for i in range(len(df.index))}
|
||||
original_libs = {originals['external_library_id'][i]: (originals['name'][i], originals['scope'][i]) for i in range(len(originals.index))}
|
||||
|
||||
for lib in libs:
|
||||
pkg, sc = libs[lib]
|
||||
|
@ -81,8 +83,8 @@ def _remove_all_new_packages(manager):
|
|||
manager.uninstall(pkg, scope=Scope.public_scope())
|
||||
|
||||
|
||||
packages = ["absl-py==0.1.13", "astor==0.6.2", "bleach==1.5.0", "cryptography==2.2.2",
|
||||
"html5lib==1.0.1", "Markdown==2.6.11", "numpy==1.14.3", "termcolor==1.1.0", "webencodings==0.5.1"]
|
||||
packages = ["absl-py==0.1.13", "astor==0.8.1", "bleach==1.5.0",
|
||||
"html5lib==1.0.1", "Markdown==2.6.11", "termcolor==1.1.0", "webencodings==0.5.1"]
|
||||
|
||||
for package in packages:
|
||||
pipdownloader = PipDownloader(connection, path_to_packages, package)
|
||||
|
@ -102,12 +104,13 @@ def test_install_basic_zip_package_different_name():
|
|||
module_name = "testpackageA"
|
||||
|
||||
_remove_all_new_packages(pkgmanager)
|
||||
|
||||
_create(module_name=module_name, package_file=package, class_to_check="ClassA")
|
||||
|
||||
|
||||
def test_install_whl_files():
|
||||
packages = ["webencodings-0.5.1-py2.py3-none-any.whl", "html5lib-1.0.1-py2.py3-none-any.whl",
|
||||
"astor-0.6.2-py2.py3-none-any.whl"]
|
||||
"astor-0.8.1-py2.py3-none-any.whl"]
|
||||
module_names = ["webencodings", "html5lib", "astor"]
|
||||
classes_to_check = ["LABELS", "parse", "code_gen"]
|
||||
|
||||
|
@ -115,10 +118,7 @@ def test_install_whl_files():
|
|||
|
||||
for package, module, class_to_check in zip(packages, module_names, classes_to_check):
|
||||
full_package = os.path.join(path_to_packages, package)
|
||||
_create(module_name=module, package_file=full_package, class_to_check=class_to_check, drop=False)
|
||||
|
||||
for name in module_names:
|
||||
_drop(package_name=name, ddl_name=name)
|
||||
_create(module_name=module, package_file=full_package, class_to_check=class_to_check)
|
||||
|
||||
|
||||
def test_install_targz_files():
|
||||
|
@ -161,74 +161,38 @@ def test_package_already_exists_on_sql_table():
|
|||
|
||||
_remove_all_new_packages(pkgmanager)
|
||||
|
||||
# Install a downgraded version of the package first
|
||||
package = os.path.join(path_to_packages, "testpackageA-0.0.1.zip")
|
||||
pkgmanager.install(package)
|
||||
|
||||
def check_version():
|
||||
import pkg_resources
|
||||
return pkg_resources.get_distribution("testpackageA").version
|
||||
|
||||
version = pyexecutor.execute_function_in_sql(check_version)
|
||||
assert version == "0.0.1"
|
||||
|
||||
package = os.path.join(path_to_packages, "testpackageA-0.0.2.zip")
|
||||
|
||||
# Without upgrade
|
||||
output = io.StringIO()
|
||||
with redirect_stdout(output):
|
||||
pkgmanager.install(package, upgrade=False)
|
||||
assert "exists on server. Set upgrade to True to force upgrade." in output.getvalue()
|
||||
assert "exists on server. Set upgrade to True" in output.getvalue()
|
||||
|
||||
version = pyexecutor.execute_function_in_sql(check_version)
|
||||
assert version == "0.0.1"
|
||||
|
||||
# With upgrade
|
||||
package = os.path.join(path_to_packages, "testpackageA-0.0.2.zip")
|
||||
pkgmanager.install(package, upgrade=True)
|
||||
|
||||
def check_version():
|
||||
import testpackageA
|
||||
return testpackageA.__version__
|
||||
|
||||
version = pyexecutor.execute_function_in_sql(check_version)
|
||||
assert version == "0.0.2"
|
||||
|
||||
pkgmanager.uninstall("testpackageA")
|
||||
|
||||
|
||||
def test_upgrade_parameter():
|
||||
|
||||
_remove_all_new_packages(pkgmanager)
|
||||
|
||||
# Get sql packages
|
||||
originalsqlpkgs = _get_sql_package_table(connection)
|
||||
|
||||
pkg = os.path.join(path_to_packages, "cryptography-2.2.2-cp35-cp35m-win_amd64.whl")
|
||||
|
||||
output = io.StringIO()
|
||||
with redirect_stdout(output):
|
||||
pkgmanager.install(pkg, upgrade=False)
|
||||
assert "exists on server. Set upgrade to True to force upgrade." in output.getvalue()
|
||||
|
||||
# Assert no additional packages were installed
|
||||
|
||||
sqlpkgs = _get_sql_package_table(connection)
|
||||
assert len(sqlpkgs) == len(originalsqlpkgs)
|
||||
|
||||
#################
|
||||
|
||||
def check_version():
|
||||
import cryptography as cp
|
||||
return cp.__version__
|
||||
|
||||
oldversion = pyexecutor.execute_function_in_sql(check_version)
|
||||
|
||||
pkgmanager.install(pkg, upgrade=True)
|
||||
|
||||
sqlpkgs = _get_sql_package_table(connection)
|
||||
assert len(sqlpkgs) == len(originalsqlpkgs) + 2
|
||||
|
||||
version = pyexecutor.execute_function_in_sql(check_version)
|
||||
assert version == "2.2.2"
|
||||
assert version > oldversion
|
||||
|
||||
pkgmanager.uninstall("cryptography")
|
||||
pkgmanager.uninstall("asn1crypto")
|
||||
|
||||
sqlpkgs = _get_sql_package_table(connection)
|
||||
assert len(sqlpkgs) == len(originalsqlpkgs)
|
||||
|
||||
|
||||
# TODO: more tests for drop external library
|
||||
|
||||
def test_scope():
|
||||
|
||||
_remove_all_new_packages(pkgmanager)
|
||||
|
@ -239,7 +203,7 @@ def test_scope():
|
|||
import testpackageA
|
||||
return testpackageA.__file__
|
||||
|
||||
_revotesterconnection = sqlmlutils.ConnectionInfo(server="localhost",
|
||||
_revotesterconnection = ConnectionInfo(server="localhost",
|
||||
database="AirlineTestDB",
|
||||
uid="Tester",
|
||||
pwd="FakeT3sterPwd!")
|
||||
|
|
|
@ -1,29 +1,33 @@
|
|||
# Copyright(c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright(c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import sqlmlutils
|
||||
import io
|
||||
import os
|
||||
import pytest
|
||||
from sqlmlutils import SQLPythonExecutor, SQLPackageManager
|
||||
from sqlmlutils.packagemanagement.scope import Scope
|
||||
from package_helper_functions import _get_sql_package_table, _get_package_names_list
|
||||
import io
|
||||
|
||||
from contextlib import redirect_stdout
|
||||
from package_helper_functions import _get_sql_package_table, _get_package_names_list
|
||||
from sqlmlutils import SQLPythonExecutor, SQLPackageManager, Scope
|
||||
|
||||
from conftest import connection
|
||||
from conftest import connection, scope
|
||||
|
||||
def _drop_all_ddl_packages(conn):
|
||||
def _drop_all_ddl_packages(conn, scope):
|
||||
pkgs = _get_sql_package_table(conn)
|
||||
for pkg in pkgs:
|
||||
try:
|
||||
SQLPackageManager(conn)._drop_sql_package(pkg['name'], scope=Scope.private_scope())
|
||||
except Exception:
|
||||
pass
|
||||
if(len(pkgs.index) > 0 ):
|
||||
for pkg in pkgs['name']:
|
||||
if pkg not in initial_list:
|
||||
try:
|
||||
SQLPackageManager(conn)._drop_sql_package(pkg, scope=scope)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def _get_initial_list(conn, scope):
|
||||
pkgs = _get_sql_package_table(conn)
|
||||
return pkgs['name']
|
||||
|
||||
pyexecutor = SQLPythonExecutor(connection)
|
||||
pkgmanager = SQLPackageManager(connection)
|
||||
_drop_all_ddl_packages(connection)
|
||||
|
||||
initial_list = _get_sql_package_table(connection)['name']
|
||||
|
||||
def _package_exists(module_name: str):
|
||||
mod = __import__(module_name)
|
||||
|
@ -36,46 +40,39 @@ def _package_no_exist(module_name: str):
|
|||
__import__(module_name)
|
||||
return True
|
||||
|
||||
|
||||
def test_install_tensorflow_and_keras():
|
||||
@pytest.mark.skip(reason="No version of tensorflow works with currently installed numpy (1.15.4)")
|
||||
def test_install_tensorflow():
|
||||
def use_tensorflow():
|
||||
import tensorflow as tf
|
||||
node1 = tf.constant(3.0, tf.float32)
|
||||
return str(node1.dtype)
|
||||
|
||||
try:
|
||||
pkgmanager.install("tensorflow", upgrade=True)
|
||||
val = pyexecutor.execute_function_in_sql(use_tensorflow)
|
||||
assert 'float32' in val
|
||||
|
||||
def use_keras():
|
||||
import keras
|
||||
|
||||
pkgmanager.install("tensorflow==1.1.0")
|
||||
val = pyexecutor.execute_function_in_sql(use_tensorflow)
|
||||
assert 'float32' in val
|
||||
|
||||
pkgmanager.install("keras")
|
||||
pyexecutor.execute_function_in_sql(use_keras)
|
||||
pkgmanager.uninstall("keras")
|
||||
val = pyexecutor.execute_function_in_sql(_package_no_exist, "keras")
|
||||
assert val
|
||||
|
||||
pkgmanager.uninstall("tensorflow")
|
||||
val = pyexecutor.execute_function_in_sql(_package_no_exist, "tensorflow")
|
||||
assert val
|
||||
|
||||
_drop_all_ddl_packages(connection)
|
||||
pkgmanager.uninstall("tensorflow")
|
||||
val = pyexecutor.execute_function_in_sql(_package_no_exist, "tensorflow")
|
||||
assert val
|
||||
finally:
|
||||
_drop_all_ddl_packages(connection, scope)
|
||||
|
||||
|
||||
def test_install_many_packages():
|
||||
packages = ["multiprocessing_on_dill", "simplejson"]
|
||||
|
||||
try:
|
||||
for package in packages:
|
||||
pkgmanager.install(package, upgrade=True)
|
||||
val = pyexecutor.execute_function_in_sql(_package_exists, module_name=package)
|
||||
assert val
|
||||
|
||||
for package in packages:
|
||||
pkgmanager.install(package, upgrade=True)
|
||||
val = pyexecutor.execute_function_in_sql(_package_exists, module_name=package)
|
||||
assert val
|
||||
|
||||
pkgmanager.uninstall(package)
|
||||
val = pyexecutor.execute_function_in_sql(_package_no_exist, module_name=package)
|
||||
assert val
|
||||
|
||||
_drop_all_ddl_packages(connection)
|
||||
pkgmanager.uninstall(package)
|
||||
val = pyexecutor.execute_function_in_sql(_package_no_exist, module_name=package)
|
||||
assert val
|
||||
finally:
|
||||
_drop_all_ddl_packages(connection, scope)
|
||||
|
||||
|
||||
def test_install_version():
|
||||
|
@ -86,99 +83,111 @@ def test_install_version():
|
|||
mod = __import__(module_name)
|
||||
return mod.__version__ == version
|
||||
|
||||
pkgmanager.install(package, version=v)
|
||||
val = pyexecutor.execute_function_in_sql(_package_version_exists, module_name=package, version=v)
|
||||
assert val
|
||||
try:
|
||||
pkgmanager.install(package, version=v)
|
||||
val = pyexecutor.execute_function_in_sql(_package_version_exists, module_name=package, version=v)
|
||||
assert val
|
||||
|
||||
pkgmanager.uninstall(package)
|
||||
val = pyexecutor.execute_function_in_sql(_package_no_exist, module_name=package)
|
||||
assert val
|
||||
|
||||
_drop_all_ddl_packages(connection)
|
||||
pkgmanager.uninstall(package)
|
||||
val = pyexecutor.execute_function_in_sql(_package_no_exist, module_name=package)
|
||||
assert val
|
||||
finally:
|
||||
_drop_all_ddl_packages(connection, scope)
|
||||
|
||||
|
||||
def test_dependency_resolution():
|
||||
package = "multiprocessing_on_dill"
|
||||
package = "latex"
|
||||
|
||||
pkgmanager.install(package, upgrade=True)
|
||||
val = pyexecutor.execute_function_in_sql(_package_exists, module_name=package)
|
||||
assert val
|
||||
try:
|
||||
pkgmanager.install(package, upgrade=True)
|
||||
val = pyexecutor.execute_function_in_sql(_package_exists, module_name=package)
|
||||
assert val
|
||||
|
||||
pkgs = _get_package_names_list(connection)
|
||||
pkgs = _get_package_names_list(connection)
|
||||
|
||||
assert package in pkgs
|
||||
assert "pyreadline" in pkgs
|
||||
assert package in pkgs
|
||||
assert "funcsigs" in pkgs
|
||||
|
||||
pkgmanager.uninstall(package)
|
||||
val = pyexecutor.execute_function_in_sql(_package_no_exist, module_name=package)
|
||||
assert val
|
||||
pkgmanager.uninstall(package)
|
||||
val = pyexecutor.execute_function_in_sql(_package_no_exist, module_name=package)
|
||||
assert val
|
||||
|
||||
_drop_all_ddl_packages(connection)
|
||||
finally:
|
||||
_drop_all_ddl_packages(connection, scope)
|
||||
|
||||
|
||||
def test_upgrade_parameter():
|
||||
|
||||
pkg = "cryptography"
|
||||
try:
|
||||
pkg = "cryptography"
|
||||
|
||||
# Get sql packages
|
||||
originalsqlpkgs = _get_sql_package_table(connection)
|
||||
first_version = "2.7"
|
||||
second_version = "2.8"
|
||||
|
||||
# Install package first so we can test upgrade param
|
||||
pkgmanager.install(pkg, version=first_version)
|
||||
|
||||
# Get sql packages
|
||||
originalsqlpkgs = _get_sql_package_table(connection)
|
||||
|
||||
output = io.StringIO()
|
||||
with redirect_stdout(output):
|
||||
pkgmanager.install(pkg, upgrade=False)
|
||||
assert "exists on server. Set upgrade to True to force upgrade." in output.getvalue()
|
||||
output = io.StringIO()
|
||||
with redirect_stdout(output):
|
||||
pkgmanager.install(pkg, upgrade=False, version=second_version)
|
||||
assert "exists on server. Set upgrade to True" in output.getvalue()
|
||||
|
||||
# Assert no additional packages were installed
|
||||
# Make sure nothing excess was accidentally installed
|
||||
|
||||
sqlpkgs = _get_sql_package_table(connection)
|
||||
assert len(sqlpkgs) == len(originalsqlpkgs)
|
||||
sqlpkgs = _get_sql_package_table(connection)
|
||||
assert len(sqlpkgs) == len(originalsqlpkgs)
|
||||
|
||||
#################
|
||||
#################
|
||||
|
||||
def check_version():
|
||||
import cryptography as cp
|
||||
return cp.__version__
|
||||
def check_version():
|
||||
import cryptography as cp
|
||||
return cp.__version__
|
||||
|
||||
oldversion = pyexecutor.execute_function_in_sql(check_version)
|
||||
oldversion = pyexecutor.execute_function_in_sql(check_version)
|
||||
|
||||
pkgmanager.install(pkg, upgrade=True)
|
||||
pkgmanager.install(pkg, upgrade=True, version=second_version)
|
||||
|
||||
afterinstall = _get_sql_package_table(connection)
|
||||
assert len(afterinstall) > len(originalsqlpkgs)
|
||||
afterinstall = _get_sql_package_table(connection)
|
||||
assert len(afterinstall) >= len(originalsqlpkgs)
|
||||
|
||||
version = pyexecutor.execute_function_in_sql(check_version)
|
||||
assert version > oldversion
|
||||
version = pyexecutor.execute_function_in_sql(check_version)
|
||||
assert version > oldversion
|
||||
|
||||
pkgmanager.uninstall("cryptography")
|
||||
pkgmanager.uninstall("cryptography")
|
||||
|
||||
sqlpkgs = _get_sql_package_table(connection)
|
||||
assert len(sqlpkgs) == len(afterinstall) - 1
|
||||
sqlpkgs = _get_sql_package_table(connection)
|
||||
assert len(sqlpkgs) == len(afterinstall) - 1
|
||||
|
||||
_drop_all_ddl_packages(connection)
|
||||
finally:
|
||||
_drop_all_ddl_packages(connection, scope)
|
||||
|
||||
|
||||
def test_install_abslpy():
|
||||
pkgmanager.install("absl-py")
|
||||
|
||||
def useit():
|
||||
import absl
|
||||
return absl.__file__
|
||||
|
||||
pyexecutor.execute_function_in_sql(useit)
|
||||
|
||||
pkgmanager.uninstall("absl-py")
|
||||
|
||||
def dontuseit():
|
||||
import pytest
|
||||
with pytest.raises(Exception):
|
||||
import absl
|
||||
|
||||
try:
|
||||
pkgmanager.install("absl-py")
|
||||
|
||||
pyexecutor.execute_function_in_sql(dontuseit)
|
||||
pyexecutor.execute_function_in_sql(useit)
|
||||
|
||||
_drop_all_ddl_packages(connection)
|
||||
pkgmanager.uninstall("absl-py")
|
||||
|
||||
pyexecutor.execute_function_in_sql(dontuseit)
|
||||
|
||||
finally:
|
||||
_drop_all_ddl_packages(connection, scope)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Theano depends on a conda package libpython? lazylinker issue")
|
||||
def test_install_theano():
|
||||
pkgmanager.install("Theano")
|
||||
|
||||
|
@ -186,15 +195,17 @@ def test_install_theano():
|
|||
import theano.tensor as T
|
||||
return str(T)
|
||||
|
||||
pyexecutor.execute_function_in_sql(useit)
|
||||
try:
|
||||
pyexecutor.execute_function_in_sql(useit)
|
||||
|
||||
pkgmanager.uninstall("Theano")
|
||||
pkgmanager.uninstall("Theano")
|
||||
|
||||
pkgmanager.install("theano")
|
||||
pyexecutor.execute_function_in_sql(useit)
|
||||
pkgmanager.uninstall("theano")
|
||||
pkgmanager.install("theano")
|
||||
pyexecutor.execute_function_in_sql(useit)
|
||||
pkgmanager.uninstall("theano")
|
||||
|
||||
_drop_all_ddl_packages(connection)
|
||||
finally:
|
||||
_drop_all_ddl_packages(connection, scope)
|
||||
|
||||
|
||||
def test_already_installed_popular_ml_packages():
|
||||
|
@ -208,17 +219,18 @@ def test_already_installed_popular_ml_packages():
|
|||
|
||||
|
||||
def test_installing_popular_ml_packages():
|
||||
newpackages = ["plotly", "cntk", "gensim"]
|
||||
newpackages = ["plotly", "gensim"]
|
||||
|
||||
def checkit(pkgname):
|
||||
val = __import__(pkgname)
|
||||
return str(val)
|
||||
|
||||
for package in newpackages:
|
||||
pkgmanager.install(package)
|
||||
pyexecutor.execute_function_in_sql(checkit, pkgname=package)
|
||||
|
||||
_drop_all_ddl_packages(connection)
|
||||
try:
|
||||
for package in newpackages:
|
||||
pkgmanager.install(package)
|
||||
pyexecutor.execute_function_in_sql(checkit, pkgname=package)
|
||||
finally:
|
||||
_drop_all_ddl_packages(connection, scope)
|
||||
|
||||
|
||||
# TODO: find a bad pypi package to test this scenario
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright(c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright(c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import sqlmlutils
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright(c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright(c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import sqlmlutils
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright(c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright(c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import sqlmlutils
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright(c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright(c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import sqlmlutils
|
||||
|
|
|
@ -2,6 +2,6 @@ def foo(t1, t2, t3):
|
|||
return str(t1)+str(t2)
|
||||
|
||||
|
||||
res = foo(t1,t2,t3)
|
||||
param_str = foo(t1,t2,t3)
|
||||
|
||||
print("Testing output!")
|
|
@ -1,13 +1,14 @@
|
|||
# Copyright(c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright(c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import pytest
|
||||
import sqlmlutils
|
||||
from contextlib import redirect_stdout
|
||||
from subprocess import Popen, PIPE, STDOUT
|
||||
from pandas import DataFrame
|
||||
import io
|
||||
import os
|
||||
import pytest
|
||||
import sqlmlutils
|
||||
|
||||
from contextlib import redirect_stdout
|
||||
from subprocess import Popen, PIPE, STDOUT
|
||||
from pandas import DataFrame, set_option
|
||||
|
||||
from conftest import connection
|
||||
|
||||
|
@ -15,6 +16,11 @@ current_dir = os.path.dirname(__file__)
|
|||
script_dir = os.path.join(current_dir, "scripts")
|
||||
sqlpy = sqlmlutils.SQLPythonExecutor(connection)
|
||||
|
||||
# Prevent truncation of DataFrame when printing
|
||||
#
|
||||
set_option("display.max_colwidth", -1)
|
||||
set_option("display.max_columns", None)
|
||||
|
||||
|
||||
###################
|
||||
# No output tests #
|
||||
|
@ -23,6 +29,9 @@ sqlpy = sqlmlutils.SQLPythonExecutor(connection)
|
|||
def test_no_output():
|
||||
def my_func():
|
||||
print("blah blah blah")
|
||||
|
||||
# Test single quotes as well
|
||||
print('Hello')
|
||||
|
||||
name = "test_no_output"
|
||||
sqlpy.drop_sproc(name)
|
||||
|
@ -30,9 +39,10 @@ def test_no_output():
|
|||
sqlpy.create_sproc_from_function(name, my_func)
|
||||
assert sqlpy.check_sproc(name)
|
||||
|
||||
x = sqlpy.execute_sproc(name)
|
||||
x, outparams = sqlpy.execute_sproc(name)
|
||||
assert type(x) == DataFrame
|
||||
assert x.empty
|
||||
assert not outparams
|
||||
|
||||
sqlpy.drop_sproc(name)
|
||||
assert not sqlpy.check_sproc(name)
|
||||
|
@ -49,14 +59,21 @@ def test_no_output_mixed_args():
|
|||
buf = io.StringIO()
|
||||
with redirect_stdout(buf):
|
||||
sqlpy.execute_sproc(name, val1=5, val2="blah", val3=15.5, val4=True)
|
||||
|
||||
assert "5 blah 15.5 True" in buf.getvalue()
|
||||
|
||||
|
||||
sqlpy.drop_sproc(name)
|
||||
assert not sqlpy.check_sproc(name)
|
||||
|
||||
|
||||
def test_no_output_mixed_args_in_df():
|
||||
def mixed(val1: int, val2: str, val3: float, val4: bool, val5: DataFrame):
|
||||
# Prevent truncation of DataFrame when printing
|
||||
#
|
||||
import pandas as pd
|
||||
pd.set_option("display.max_colwidth", -1)
|
||||
pd.set_option("display.max_columns", None)
|
||||
|
||||
print(val1, val2, val3, val4)
|
||||
print(val5)
|
||||
|
||||
|
@ -64,9 +81,11 @@ def test_no_output_mixed_args_in_df():
|
|||
sqlpy.drop_sproc(name)
|
||||
|
||||
sqlpy.create_sproc_from_function(name, mixed)
|
||||
|
||||
buf = io.StringIO()
|
||||
with redirect_stdout(buf):
|
||||
sqlpy.execute_sproc(name, val1=5, val2="blah", val3=15.5, val4=False, val5="SELECT TOP 2 * FROM airline5000")
|
||||
|
||||
assert "5 blah 15.5 False" in buf.getvalue()
|
||||
assert "ArrTime" in buf.getvalue()
|
||||
assert "CRSDepTime" in buf.getvalue()
|
||||
|
@ -79,18 +98,26 @@ def test_no_output_mixed_args_in_df():
|
|||
|
||||
|
||||
def test_no_output_mixed_args_in_df_in_params():
|
||||
def mixed(val1, val2, val3, val4, val5):
|
||||
print(val1, val2, val3, val5)
|
||||
print(val4)
|
||||
def mixed(val1: int, val2: str, val3: float, val4: bool, val5: DataFrame):
|
||||
# Prevent truncation of DataFrame when printing
|
||||
#
|
||||
import pandas as pd
|
||||
pd.set_option("display.max_colwidth", -1)
|
||||
pd.set_option("display.max_columns", None)
|
||||
|
||||
print(val1, val2, val3, val4)
|
||||
print(val5)
|
||||
|
||||
in_params = {"val1": int, "val2": str, "val3": float, "val4": DataFrame, "val5": bool}
|
||||
in_params = {"val1": int, "val2": str, "val3": float, "val4": bool, "val5": DataFrame}
|
||||
name = "test_no_output_mixed_args_in_df_in_params"
|
||||
sqlpy.drop_sproc(name)
|
||||
|
||||
sqlpy.create_sproc_from_function(name=name, func=mixed, input_params=in_params)
|
||||
|
||||
buf = io.StringIO()
|
||||
with redirect_stdout(buf):
|
||||
sqlpy.execute_sproc(name, val1=5, val2="blah", val3=15.5, val4="SELECT TOP 2 * FROM airline5000", val5=False)
|
||||
sqlpy.execute_sproc(name, val1=5, val2="blah", val3=15.5, val4=False, val5="SELECT TOP 2 * FROM airline5000")
|
||||
|
||||
assert "5 blah 15.5 False" in buf.getvalue()
|
||||
assert "ArrTime" in buf.getvalue()
|
||||
assert "CRSDepTime" in buf.getvalue()
|
||||
|
@ -118,8 +145,9 @@ def test_out_df_no_params():
|
|||
sqlpy.create_sproc_from_function(name, no_params)
|
||||
assert sqlpy.check_sproc(name)
|
||||
|
||||
df = sqlpy.execute_sproc(name)
|
||||
df, outparams = sqlpy.execute_sproc(name)
|
||||
assert list(df.iloc[:,0] == [1, 2, 3, 4, 5])
|
||||
assert not outparams
|
||||
|
||||
sqlpy.drop_sproc(name)
|
||||
assert not sqlpy.check_sproc(name)
|
||||
|
@ -140,9 +168,10 @@ def test_out_df_with_args():
|
|||
for values in vals:
|
||||
arg1 = values[0]
|
||||
arg2 = values[1]
|
||||
res = sqlpy.execute_sproc(name, arg1=arg1, arg2=arg2)
|
||||
assert res[0][0] == arg1
|
||||
assert res[1][0] == arg2
|
||||
res, outparams = sqlpy.execute_sproc(name, arg1=arg1, arg2=arg2)
|
||||
assert res.loc[0].iloc[0] == arg1
|
||||
assert res.loc[0].iloc[1] == arg2
|
||||
assert not outparams
|
||||
|
||||
sqlpy.drop_sproc(name)
|
||||
assert not sqlpy.check_sproc(name)
|
||||
|
@ -158,10 +187,11 @@ def test_out_df_in_df():
|
|||
sqlpy.create_sproc_from_function(name, in_data)
|
||||
assert sqlpy.check_sproc(name)
|
||||
|
||||
res = sqlpy.execute_sproc(name, in_df="SELECT TOP 10 * FROM airline5000")
|
||||
res, outparams = sqlpy.execute_sproc(name, in_df="SELECT TOP 10 * FROM airline5000")
|
||||
|
||||
assert type(res) == DataFrame
|
||||
assert res.shape == (10, 30)
|
||||
assert not outparams
|
||||
|
||||
sqlpy.drop_sproc(name)
|
||||
assert not sqlpy.check_sproc(name)
|
||||
|
@ -169,7 +199,14 @@ def test_out_df_in_df():
|
|||
|
||||
def test_out_df_mixed_args_in_df():
|
||||
def mixed(val1: int, val2: str, val3: float, val4: DataFrame, val5: bool):
|
||||
# Prevent truncation of DataFrame when printing
|
||||
#
|
||||
import pandas as pd
|
||||
pd.set_option("display.max_colwidth", -1)
|
||||
pd.set_option("display.max_columns", None)
|
||||
|
||||
print(val1, val2, val3, val5)
|
||||
|
||||
if val5 and val1 == 5 and val2 == "blah" and val3 == 15.5:
|
||||
return val4
|
||||
else:
|
||||
|
@ -180,11 +217,12 @@ def test_out_df_mixed_args_in_df():
|
|||
|
||||
sqlpy.create_sproc_from_function(name, mixed)
|
||||
|
||||
res = sqlpy.execute_sproc(name, val1=5, val2="blah", val3=15.5,
|
||||
res, outparams = sqlpy.execute_sproc(name, val1=5, val2="blah", val3=15.5,
|
||||
val4="SELECT TOP 10 * FROM airline5000", val5=True)
|
||||
|
||||
assert type(res) == DataFrame
|
||||
assert res.shape == (10, 30)
|
||||
assert not outparams
|
||||
|
||||
sqlpy.drop_sproc(name)
|
||||
assert not sqlpy.check_sproc(name)
|
||||
|
@ -192,6 +230,12 @@ def test_out_df_mixed_args_in_df():
|
|||
|
||||
def test_out_df_mixed_in_params_in_df():
|
||||
def mixed(val1, val2, val3, val4, val5):
|
||||
# Prevent truncation of DataFrame when printing
|
||||
#
|
||||
import pandas as pd
|
||||
pd.set_option("display.max_colwidth", -1)
|
||||
pd.set_option("display.max_columns", None)
|
||||
|
||||
print(val1, val2, val3, val5)
|
||||
if val5 and val1 == 5 and val2 == "blah" and val3 == 15.5:
|
||||
return val4
|
||||
|
@ -206,11 +250,12 @@ def test_out_df_mixed_in_params_in_df():
|
|||
sqlpy.create_sproc_from_function(name, mixed, input_params=input_params)
|
||||
assert sqlpy.check_sproc(name)
|
||||
|
||||
res = sqlpy.execute_sproc(name, val1=5, val2="blah", val3=15.5,
|
||||
res, outparams = sqlpy.execute_sproc(name, val1=5, val2="blah", val3=15.5,
|
||||
val4="SELECT TOP 10 * FROM airline5000", val5=True)
|
||||
|
||||
assert type(res) == DataFrame
|
||||
assert res.shape == (10, 30)
|
||||
assert not outparams
|
||||
|
||||
sqlpy.drop_sproc(name)
|
||||
assert not sqlpy.check_sproc(name)
|
||||
|
@ -232,41 +277,42 @@ def test_out_of_order_args():
|
|||
v2 = "blah"
|
||||
v3 = 15.5
|
||||
v4 = "SELECT TOP 10 * FROM airline5000"
|
||||
res = sqlpy.execute_sproc(name, val5=False, val3=v3, val4=v4, val1=v1, val2=v2)
|
||||
|
||||
assert res[0][0] == v1
|
||||
assert res[1][0] == v2
|
||||
assert res[2][0] == v3
|
||||
assert not res[3][0]
|
||||
res, outparams = sqlpy.execute_sproc(name, val5=False, val3=v3, val4=v4, val1=v1, val2=v2)
|
||||
|
||||
assert res.loc[0].iloc[0] == v1
|
||||
assert res.loc[0].iloc[1] == v2
|
||||
assert res.loc[0].iloc[2] == v3
|
||||
assert not res.loc[0].iloc[3]
|
||||
assert not outparams
|
||||
|
||||
sqlpy.drop_sproc(name)
|
||||
assert not sqlpy.check_sproc(name)
|
||||
|
||||
|
||||
# TODO: Output Params execution not currently supported
|
||||
def test_in_param_out_param():
|
||||
def in_out(t1, t2, t3):
|
||||
# Prevent truncation of DataFrame when printing
|
||||
#
|
||||
import pandas as pd
|
||||
pd.set_option("display.max_colwidth", -1)
|
||||
pd.set_option("display.max_columns", None)
|
||||
|
||||
print(t2)
|
||||
print(t3)
|
||||
res = "Hello " + t1
|
||||
return {'out_df': t3, 'res': res}
|
||||
param_str = "Hello " + t1
|
||||
return {"out_df": t3, "param_str": param_str}
|
||||
|
||||
name = "test_in_param_out_param"
|
||||
sqlpy.drop_sproc(name)
|
||||
|
||||
input_params = {"t1": str, "t2": int, "t3": DataFrame}
|
||||
output_params = {"res": str, "out_df": DataFrame}
|
||||
output_params = {"param_str": str, "out_df": DataFrame}
|
||||
|
||||
sqlpy.create_sproc_from_function(name, in_out, input_params=input_params, output_params=output_params)
|
||||
assert sqlpy.check_sproc(name)
|
||||
|
||||
# Out params don't currently work so we use sqlcmd to test the output param sproc
|
||||
sql_str = "DECLARE @res nvarchar(max) EXEC test_in_param_out_param @t2 = 213, @t1 = N'Hello', " \
|
||||
"@t3 = N'select top 10 * from airline5000', @res = @res OUTPUT SELECT @res as N'@res'"
|
||||
p = Popen(["sqlcmd", "-S", connection.server, "-E", "-d", connection.database, "-Q", sql_str],
|
||||
shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT)
|
||||
output = p.stdout.read()
|
||||
assert "Hello Hello" in output.decode()
|
||||
res, outparams = sqlpy.execute_sproc(name, output_params = output_params, t1="Hello", t2 = 213, t3 = "select top 10 * from airline5000")
|
||||
assert "Hello Hello" in outparams["param_str"]
|
||||
|
||||
sqlpy.drop_sproc(name)
|
||||
assert not sqlpy.check_sproc(name)
|
||||
|
@ -284,10 +330,11 @@ def test_in_df_out_df_dict():
|
|||
sqlpy.create_sproc_from_function(name, func, output_params=output_params)
|
||||
assert sqlpy.check_sproc(name)
|
||||
|
||||
res = sqlpy.execute_sproc(name, in_df="SELECT TOP 10 * FROM airline5000")
|
||||
res, outparams = sqlpy.execute_sproc(name, in_df="SELECT TOP 10 * FROM airline5000")
|
||||
|
||||
assert type(res) == DataFrame
|
||||
assert res.shape == (10, 30)
|
||||
assert not outparams
|
||||
|
||||
sqlpy.drop_sproc(name)
|
||||
assert not sqlpy.check_sproc(name)
|
||||
|
@ -295,12 +342,12 @@ def test_in_df_out_df_dict():
|
|||
|
||||
################
|
||||
# Script Tests #
|
||||
################
|
||||
#################
|
||||
|
||||
def test_script_no_params():
|
||||
script = os.path.join(script_dir, "test_script_no_params.py")
|
||||
script = os.path.join(script_dir, "exec_script_no_params.py")
|
||||
|
||||
name = "test_script_no_params"
|
||||
name = "exec_script_no_params"
|
||||
sqlpy.drop_sproc(name)
|
||||
|
||||
sqlpy.create_sproc_from_script(name, script)
|
||||
|
@ -309,6 +356,7 @@ def test_script_no_params():
|
|||
buf = io.StringIO()
|
||||
with redirect_stdout(buf):
|
||||
sqlpy.execute_sproc(name)
|
||||
|
||||
assert "No Inputs" in buf.getvalue()
|
||||
assert "Required" in buf.getvalue()
|
||||
assert "Testing output!" in buf.getvalue()
|
||||
|
@ -319,9 +367,9 @@ def test_script_no_params():
|
|||
|
||||
|
||||
def test_script_no_out_params():
|
||||
script = os.path.join(script_dir, "test_script_no_out_params.py")
|
||||
script = os.path.join(script_dir, "exec_script_no_out_params.py")
|
||||
|
||||
name = "test_script_no_out_params"
|
||||
name = "exec_script_no_out_params"
|
||||
sqlpy.drop_sproc(name)
|
||||
|
||||
input_params = {"t1": str, "t2": str, "t3": int}
|
||||
|
@ -332,6 +380,7 @@ def test_script_no_out_params():
|
|||
buf = io.StringIO()
|
||||
with redirect_stdout(buf):
|
||||
sqlpy.execute_sproc(name, t1="Hello", t2="World", t3=312)
|
||||
|
||||
assert "HelloWorld" in buf.getvalue()
|
||||
assert "312" in buf.getvalue()
|
||||
assert "Testing output!" in buf.getvalue()
|
||||
|
@ -341,9 +390,9 @@ def test_script_no_out_params():
|
|||
|
||||
|
||||
def test_script_out_df():
|
||||
script = os.path.join(script_dir, "test_script_sproc_out_df.py")
|
||||
script = os.path.join(script_dir, "exec_script_sproc_out_df.py")
|
||||
|
||||
name = "test_script_out_df"
|
||||
name = "exec_script_out_df"
|
||||
sqlpy.drop_sproc(name)
|
||||
|
||||
input_params = {"t1": str, "t2": int, "t3": DataFrame}
|
||||
|
@ -351,36 +400,31 @@ def test_script_out_df():
|
|||
sqlpy.create_sproc_from_script(name, script, input_params)
|
||||
assert sqlpy.check_sproc(name)
|
||||
|
||||
res = sqlpy.execute_sproc(name, t1="Hello", t2=2313, t3="SELECT TOP 10 * FROM airline5000")
|
||||
|
||||
res, outparams = sqlpy.execute_sproc(name, t1="Hello", t2=2313, t3="SELECT TOP 10 * FROM airline5000")
|
||||
|
||||
assert type(res) == DataFrame
|
||||
assert res.shape == (10, 30)
|
||||
assert not outparams
|
||||
|
||||
sqlpy.drop_sproc(name)
|
||||
assert not sqlpy.check_sproc(name)
|
||||
|
||||
|
||||
#TODO: Output Params execution not currently supported
|
||||
def test_script_out_param():
|
||||
script = os.path.join(script_dir, "test_script_out_param.py")
|
||||
script = os.path.join(script_dir, "exec_script_out_param.py")
|
||||
|
||||
name = "test_script_out_param"
|
||||
name = "exec_script_out_param"
|
||||
sqlpy.drop_sproc(name)
|
||||
|
||||
input_params = {"t1": str, "t2": int, "t3": DataFrame}
|
||||
output_params = {"res": str}
|
||||
output_params = {"param_str": str}
|
||||
|
||||
sqlpy.create_sproc_from_script(name, script, input_params, output_params)
|
||||
assert sqlpy.check_sproc(name)
|
||||
|
||||
# Out params don't currently work so we use sqlcmd to test the output param sproc
|
||||
sql_str = "DECLARE @res nvarchar(max) EXEC test_script_out_param @t2 = 123, @t1 = N'Hello', " \
|
||||
"@t3 = N'select top 10 * from airline5000', @res = @res OUTPUT SELECT @res as N'@res'"
|
||||
p = Popen(["sqlcmd", "-S", connection.server, "-E", "-d", connection.database, "-Q", sql_str],
|
||||
shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT)
|
||||
output = p.stdout.read()
|
||||
assert "Hello123" in output.decode()
|
||||
|
||||
|
||||
res, outparams = sqlpy.execute_sproc(name, output_params = output_params, t1="Hello", t2 = 123, t3 = "select top 10 * from airline5000")
|
||||
assert "Hello123" in outparams["param_str"]
|
||||
|
||||
sqlpy.drop_sproc(name)
|
||||
assert not sqlpy.check_sproc(name)
|
||||
|
||||
|
@ -398,8 +442,10 @@ def test_execute_bad_param_types():
|
|||
|
||||
def func(input1: bool):
|
||||
pass
|
||||
|
||||
name = "BadInput"
|
||||
sqlpy.drop_sproc(name)
|
||||
|
||||
sqlpy.create_sproc_from_function(name, func)
|
||||
assert sqlpy.check_sproc(name)
|
||||
|
||||
|
|
13
README.md
|
@ -12,17 +12,20 @@ Currently, only the R version of sqlmlutils is supported in Azure SQL Database.
|
|||
To install sqlmlutils, follow the instructions below for Python and R, respectively.
|
||||
|
||||
Python:
|
||||
1. If your client is a Linux machine, you can skip this step.
|
||||
If your client is a Windows machine: go to https://www.lfd.uci.edu/~gohlke/pythonlibs/#pymssql and download the correct version of pymssql for your client. Run ```pip install pymssql-2.1.4.dev5-cpXX-cpXXm-win_amd64.whl``` on that file to install pymssql.
|
||||
2. Run
|
||||
To install from PyPI:
|
||||
Run
|
||||
```
|
||||
pip install sqlmlutils
|
||||
```
|
||||
To install from file:
|
||||
```
|
||||
pip install Python/dist/sqlmlutils-1.0.0.zip
|
||||
```
|
||||
|
||||
R:
|
||||
```
|
||||
R -e "install.packages('RODBCext', repos='https://cran.microsoft.com')"
|
||||
R CMD INSTALL sqlmlutils_0.7.1.zip
|
||||
R CMD INSTALL R/dist/sqlmlutils_0.7.1.zip
|
||||
```
|
||||
|
||||
# Details
|
||||
|
@ -44,6 +47,6 @@ The goal of this utility is to allow users to create and execute stored procedur
|
|||
|
||||
## Package Management
|
||||
|
||||
##### Package management with sqlmlutils is supported in SQL Server 2019 CTP 2.4 and later.
|
||||
##### R and Python package management with sqlmlutils is supported in SQL Server 2019 CTP 2.4 and later.
|
||||
|
||||
With package management, users can install packages to a remote SQL database from a client machine. The packages are downloaded on the client and then sent to the SQL database, where they are installed into library folders. The folders are per-database, so packages are always installed and made available for a specific database. The package management APIs provide PUBLIC and PRIVATE folders: packages in the PUBLIC folder are accessible to all database users, while packages in the PRIVATE folder are accessible only to the user who installed them.
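A hedged sketch of that flow from a Python client (server, credentials, and package name are placeholders; the install/uninstall calls follow the patterns used in this repo's tests):

```python
import sqlmlutils
from sqlmlutils import SQLPackageManager, Scope

connection = sqlmlutils.ConnectionInfo(server="localhost", database="AirlineTestDB",
                                       uid="Tester", pwd="FakeT3sterPwd!")
pkgmanager = SQLPackageManager(connection)

pkgmanager.install("astor")                                 # downloaded on the client, installed in the database
pkgmanager.uninstall("astor", scope=Scope.private_scope())  # PRIVATE installs are visible only to this user
```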
|
||||
|
|