Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/apps/datasource/crud/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table

conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
sql: str = ""
if ds.type == "mysql" or ds.type == "doris" or ds.type == "starrocks":
if ds.type == "mysql" or ds.type == "doris" or ds.type == "starrocks" or ds.type == "hive":
sql = f"""SELECT `{"`, `".join(fields)}` FROM `{data.table.table_name}`
{where}
LIMIT 100"""
Expand Down
43 changes: 11 additions & 32 deletions backend/apps/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from sqlalchemy.pool import NullPool
from pyhive import hive


try:
if os.path.exists(settings.ORACLE_CLIENT_PATH):
oracledb.init_oracle_client(
Expand Down Expand Up @@ -159,9 +158,10 @@ def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine:
poolclass=NullPool)
elif equals_ignore_case(ds.type, 'oracle'):
engine = create_engine(get_uri(ds), poolclass=NullPool)
elif equals_ignore_case(ds.type, 'mysql'): # mysql
elif equals_ignore_case(ds.type, 'mysql'): # mysql
ssl_mode = {"require": True} if conf.ssl else None
engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout, "ssl": ssl_mode}, poolclass=NullPool)
engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout, "ssl": ssl_mode},
poolclass=NullPool)
elif equals_ignore_case(ds.type, 'sqlite'):
engine = create_engine(get_uri(ds), connect_args={"check_same_thread": False}, poolclass=NullPool)
else: # ck
Expand Down Expand Up @@ -271,7 +271,7 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs
if is_raise:
raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}')
return False

elif equals_ignore_case(ds.type, 'es'):
es_conn = get_es_connect(conf)
if es_conn.ping():
Expand Down Expand Up @@ -314,7 +314,7 @@ def get_version(ds: CoreDatasource | AssistantOutDsSchema):
# conf.timeout = 10
db = DB.get_db(ds.type)
sql = get_version_sql(ds, conf)
if equals_ignore_case(ds.type, 'sqlite'):
if not sql:
return ''
try:
if db.connect_type == ConnectType.sqlalchemy:
Expand Down Expand Up @@ -397,30 +397,6 @@ def get_schema(ds: CoreDatasource):
res = cursor.fetchall()
res_list = [item[0] for item in res]
return res_list
elif equals_ignore_case(ds.type, 'hive'):
conn = hive.connect(host=conf.host, port=conf.port, username=conf.username,
database=conf.database, **extra_config_dict)
cursor = conn.cursor()
cursor.execute('SHOW DATABASES')
res = cursor.fetchall()
res_list = [item[0] for item in res]
cursor.close()
conn.close()
return res_list
elif equals_ignore_case(ds.type, 'doris', 'starrocks'):
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
port=conf.port, db=conf.database, connect_timeout=10,
read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor:
cursor.execute('SHOW DATABASES')
res = cursor.fetchall()
res_list = [item[0] for item in res]
return res_list
elif equals_ignore_case(ds.type, 'ck'):
with get_session(ds) as session:
with session.execute(text('SHOW DATABASES')) as result:
res = result.fetchall()
res_list = [item[0] for item in res]
return res_list


def get_tables(ds: CoreDatasource):
Expand Down Expand Up @@ -456,7 +432,8 @@ def get_tables(ds: CoreDatasource):
ssl_args = {'ssl': {'ssl_mode': 'REQUIRE'}} if conf.ssl else {}
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
port=conf.port, db=conf.database, connect_timeout=conf.timeout,
read_timeout=conf.timeout, **extra_config_dict, **ssl_args) as conn, conn.cursor() as cursor:
read_timeout=conf.timeout, **extra_config_dict,
**ssl_args) as conn, conn.cursor() as cursor:
cursor.execute(sql, (sql_param,))
res = cursor.fetchall()
res_list = [TableSchema(*item) for item in res]
Expand Down Expand Up @@ -527,7 +504,8 @@ def get_fields(ds: CoreDatasource, table_name: str = None):
ssl_args = {'ssl': {'ssl_mode': 'REQUIRE'}} if conf.ssl else {}
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
port=conf.port, db=conf.database, connect_timeout=conf.timeout,
read_timeout=conf.timeout, **extra_config_dict, **ssl_args) as conn, conn.cursor() as cursor:
read_timeout=conf.timeout, **extra_config_dict,
**ssl_args) as conn, conn.cursor() as cursor:
cursor.execute(sql, (p1, p2))
res = cursor.fetchall()
res_list = [ColumnSchema(*item) for item in res]
Expand Down Expand Up @@ -684,7 +662,8 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
ssl_args = {'ssl': {'ssl_mode': 'REQUIRE'}} if conf.ssl else {}
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
port=conf.port, db=conf.database, connect_timeout=conf.timeout,
read_timeout=conf.timeout, **extra_config_dict, **ssl_args) as conn, conn.cursor() as cursor:
read_timeout=conf.timeout, **extra_config_dict,
**ssl_args) as conn, conn.cursor() as cursor:
try:
cursor.execute(sql)
res = cursor.fetchall()
Expand Down
4 changes: 1 addition & 3 deletions backend/apps/db/db_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def get_version_sql(ds: CoreDatasource, conf: DatasourceConf):
return """
SELECT * FROM v$version
"""
elif equals_ignore_case(ds.type, "redshift"):
return ''
elif equals_ignore_case(ds.type, "sqlite"):
elif equals_ignore_case(ds.type, "redshift", "sqlite", "hive"):
return ''


Expand Down
2 changes: 1 addition & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ dependencies = [
"ldap3>=2.9.1",
"sqlglot>=28.6.0",
"numpy==2.3.5",
"pyhive[hive]>=0.7.0",
"pyhive[hive_pure_sasl]>=0.7.0",
"thrift-sasl"
]

Expand Down
Loading