Skip to content

Commit 09adb50

Browse files
committed
fix: Fix the data source-related issues in the PR
1 parent 5df4654 commit 09adb50

4 files changed

Lines changed: 14 additions & 37 deletions

File tree

backend/apps/datasource/crud/datasource.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table
329329

330330
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
331331
sql: str = ""
332-
if ds.type == "mysql" or ds.type == "doris" or ds.type == "starrocks":
332+
if ds.type == "mysql" or ds.type == "doris" or ds.type == "starrocks" or ds.type == "hive":
333333
sql = f"""SELECT `{"`, `".join(fields)}` FROM `{data.table.table_name}`
334334
{where}
335335
LIMIT 100"""

backend/apps/db/db.py

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from sqlalchemy.pool import NullPool
3939
from pyhive import hive
4040

41-
4241
try:
4342
if os.path.exists(settings.ORACLE_CLIENT_PATH):
4443
oracledb.init_oracle_client(
@@ -159,9 +158,10 @@ def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine:
159158
poolclass=NullPool)
160159
elif equals_ignore_case(ds.type, 'oracle'):
161160
engine = create_engine(get_uri(ds), poolclass=NullPool)
162-
elif equals_ignore_case(ds.type, 'mysql'): # mysql
161+
elif equals_ignore_case(ds.type, 'mysql'): # mysql
163162
ssl_mode = {"require": True} if conf.ssl else None
164-
engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout, "ssl": ssl_mode}, poolclass=NullPool)
163+
engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout, "ssl": ssl_mode},
164+
poolclass=NullPool)
165165
elif equals_ignore_case(ds.type, 'sqlite'):
166166
engine = create_engine(get_uri(ds), connect_args={"check_same_thread": False}, poolclass=NullPool)
167167
else: # ck
@@ -271,7 +271,7 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs
271271
if is_raise:
272272
raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}')
273273
return False
274-
274+
275275
elif equals_ignore_case(ds.type, 'es'):
276276
es_conn = get_es_connect(conf)
277277
if es_conn.ping():
@@ -314,7 +314,7 @@ def get_version(ds: CoreDatasource | AssistantOutDsSchema):
314314
# conf.timeout = 10
315315
db = DB.get_db(ds.type)
316316
sql = get_version_sql(ds, conf)
317-
if equals_ignore_case(ds.type, 'sqlite'):
317+
if not sql:
318318
return ''
319319
try:
320320
if db.connect_type == ConnectType.sqlalchemy:
@@ -397,30 +397,6 @@ def get_schema(ds: CoreDatasource):
397397
res = cursor.fetchall()
398398
res_list = [item[0] for item in res]
399399
return res_list
400-
elif equals_ignore_case(ds.type, 'hive'):
401-
conn = hive.connect(host=conf.host, port=conf.port, username=conf.username,
402-
database=conf.database, **extra_config_dict)
403-
cursor = conn.cursor()
404-
cursor.execute('SHOW DATABASES')
405-
res = cursor.fetchall()
406-
res_list = [item[0] for item in res]
407-
cursor.close()
408-
conn.close()
409-
return res_list
410-
elif equals_ignore_case(ds.type, 'doris', 'starrocks'):
411-
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
412-
port=conf.port, db=conf.database, connect_timeout=10,
413-
read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor:
414-
cursor.execute('SHOW DATABASES')
415-
res = cursor.fetchall()
416-
res_list = [item[0] for item in res]
417-
return res_list
418-
elif equals_ignore_case(ds.type, 'ck'):
419-
with get_session(ds) as session:
420-
with session.execute(text('SHOW DATABASES')) as result:
421-
res = result.fetchall()
422-
res_list = [item[0] for item in res]
423-
return res_list
424400

425401

426402
def get_tables(ds: CoreDatasource):
@@ -456,7 +432,8 @@ def get_tables(ds: CoreDatasource):
456432
ssl_args = {'ssl': {'ssl_mode': 'REQUIRE'}} if conf.ssl else {}
457433
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
458434
port=conf.port, db=conf.database, connect_timeout=conf.timeout,
459-
read_timeout=conf.timeout, **extra_config_dict, **ssl_args) as conn, conn.cursor() as cursor:
435+
read_timeout=conf.timeout, **extra_config_dict,
436+
**ssl_args) as conn, conn.cursor() as cursor:
460437
cursor.execute(sql, (sql_param,))
461438
res = cursor.fetchall()
462439
res_list = [TableSchema(*item) for item in res]
@@ -527,7 +504,8 @@ def get_fields(ds: CoreDatasource, table_name: str = None):
527504
ssl_args = {'ssl': {'ssl_mode': 'REQUIRE'}} if conf.ssl else {}
528505
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
529506
port=conf.port, db=conf.database, connect_timeout=conf.timeout,
530-
read_timeout=conf.timeout, **extra_config_dict, **ssl_args) as conn, conn.cursor() as cursor:
507+
read_timeout=conf.timeout, **extra_config_dict,
508+
**ssl_args) as conn, conn.cursor() as cursor:
531509
cursor.execute(sql, (p1, p2))
532510
res = cursor.fetchall()
533511
res_list = [ColumnSchema(*item) for item in res]
@@ -684,7 +662,8 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
684662
ssl_args = {'ssl': {'ssl_mode': 'REQUIRE'}} if conf.ssl else {}
685663
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
686664
port=conf.port, db=conf.database, connect_timeout=conf.timeout,
687-
read_timeout=conf.timeout, **extra_config_dict, **ssl_args) as conn, conn.cursor() as cursor:
665+
read_timeout=conf.timeout, **extra_config_dict,
666+
**ssl_args) as conn, conn.cursor() as cursor:
688667
try:
689668
cursor.execute(sql)
690669
res = cursor.fetchall()

backend/apps/db/db_sql.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ def get_version_sql(ds: CoreDatasource, conf: DatasourceConf):
2929
return """
3030
SELECT * FROM v$version
3131
"""
32-
elif equals_ignore_case(ds.type, "redshift"):
33-
return ''
34-
elif equals_ignore_case(ds.type, "sqlite"):
32+
elif equals_ignore_case(ds.type, "redshift", "sqlite", "hive"):
3533
return ''
3634

3735

backend/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ dependencies = [
5454
"ldap3>=2.9.1",
5555
"sqlglot>=28.6.0",
5656
"numpy==2.3.5",
57-
"pyhive[hive]>=0.7.0",
57+
"pyhive[hive_pure_sasl]>=0.7.0",
5858
"thrift-sasl"
5959
]
6060

0 commit comments

Comments
 (0)