Skip to content

Commit

Permalink
[sparksql] Improve session reuse and fix corner cases (#2851)
Browse files Browse the repository at this point in the history
- Improve session handling
- Fix failing corner cases
- Add checks for different session states
- Cancel statement improvements
- Fix failing UTs
  • Loading branch information
Harshg999 authored May 18, 2022
1 parent cba12fb commit 33d4f05
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 29 deletions.
3 changes: 3 additions & 0 deletions apps/spark/src/spark/livy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def close(self, uuid):
# List batches: returns the result of `self._root.get('batches')` unchanged.
def get_batches(self):
return self._root.get('batches')

# Cancel a single statement: POSTs to 'sessions/<session>/statements/<statement_id>/cancel'
# and returns the result of that call.
# NOTE(review): `session` is interpolated with %s into the path, so it is presumably the
# session id rather than a session dict - confirm against callers.
def cancel_statement(self, session, statement_id):
return self._root.post('sessions/%s/statements/%s/cancel' % (session, statement_id))

# Submit a batch: sets 'proxyUser' to `self.user`, then POSTs the JSON-encoded properties
# to 'batches' and returns the result of that call.
# NOTE: `properties` is modified in place - the caller's dict gains (or has overwritten)
# the 'proxyUser' key.
def submit_batch(self, properties):
properties['proxyUser'] = self.user
return self._root.post('batches', data=json.dumps(properties), contenttype=_JSON_CONTENT_TYPE)
Expand Down
121 changes: 92 additions & 29 deletions desktop/libs/notebook/src/notebook/connectors/spark_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,26 @@ def _get_session_key(self):
}


def _check_session(self, session):
'''
Check if the session is actually present and its state is healthy.
'''
api = self.get_api()
try:
session_present = api.get_session(session['id'])
except Exception as e:
session_present = None

if session_present and session_present['state'] not in ('dead', 'shutting_down', 'error', 'killed'):
return session_present


def create_session(self, lang='scala', properties=None):
api = self.get_api()
session_key = self._get_session_key()

if SESSIONS.get(session_key):
# Checking if the session is actually present to avoid stale value
session_present = api.get_session(SESSIONS[session_key]['id'])
session_present = self._check_session(SESSIONS[session_key])
if session_present:
return SESSIONS[session_key]

Expand Down Expand Up @@ -161,15 +174,18 @@ def execute(self, notebook, snippet):
api = self.get_api()
session = _get_snippet_session(notebook, snippet)

response = self._execute(api, session, snippet['statement'])
response = self._execute(api, session, snippet.get('type'), snippet['statement'])
return response


def _execute(self, api, session, statement):
def _execute(self, api, session, snippet_type, statement):
session_key = self._get_session_key()

if session['id'] is None and SESSIONS.get(session_key) is not None:
session = SESSIONS[session_key]
if not session or not self._check_session(session):
if SESSIONS.get(session_key) and self._check_session(SESSIONS[session_key]):
session = SESSIONS[session_key]
else:
session = self.create_session(snippet_type)

try:
response = api.submit_statement(session['id'], statement)
Expand All @@ -191,6 +207,8 @@ def check_status(self, notebook, snippet):
session = _get_snippet_session(notebook, snippet)
cell = snippet['result']['handle']['id']

session = self._handle_session_health_check(session)

try:
response = api.fetch_data(session['id'], cell)
return {
Expand All @@ -209,6 +227,8 @@ def fetch_result(self, notebook, snippet, rows, start_over):
session = _get_snippet_session(notebook, snippet)
cell = snippet['result']['handle']['id']

session = self._handle_session_health_check(session)

response = self._fetch_result(api, session, cell, start_over)
return response

Expand Down Expand Up @@ -279,16 +299,43 @@ def _fetch_result(self, api, session, cell, start_over):
def cancel(self, notebook, snippet):
  '''
  Cancel the work attached to the snippet's session.

  The session is first validated with _handle_session_health_check, which raises
  PopupException when no healthy session can be found. The cancel request itself is
  best-effort: a failure is only logged at debug level.

  Always returns {'status': 0} once the health check has passed.
  '''
  api = self.get_api()
  session = _get_snippet_session(notebook, snippet)

  session = self._handle_session_health_check(session)

  try:
    # Single, guarded cancel call; its response is not used.
    api.cancel(session['id'])
  except Exception as e:
    message = force_unicode(str(e)).lower()
    LOG.debug(message)

  return {'status': 0}


def get_log(self, notebook, snippet, startFrom=0, size=None):
  '''
  Fetch the log of the snippet's session.

  The session is first validated with _handle_session_health_check, which raises
  PopupException when no healthy session can be found.

  Returns whatever api.get_log returns for the session, or the fallback {'status': 0}
  when the call fails with a RestException (the error is logged at debug level).
  '''
  response = {'status': 0}
  api = self.get_api()
  session = _get_snippet_session(notebook, snippet)

  session = self._handle_session_health_check(session)

  try:
    response = api.get_log(session['id'], startFrom=startFrom, size=size)
  except RestException as e:
    message = force_unicode(str(e)).lower()
    LOG.debug(message)

  return response


def _handle_session_health_check(self, session):
  '''
  Return a healthy session to work with, or fail.

  The given session is returned as is when _check_session accepts it. Otherwise the
  session stored in SESSIONS under this connector's session key is returned, provided
  it passes _check_session as well. Raises PopupException when neither is usable.
  '''
  session_key = self._get_session_key()

  if session and self._check_session(session):
    return session

  # The passed session is missing or unhealthy: fall back to the stored one.
  stored_session = SESSIONS.get(session_key)
  if stored_session and self._check_session(stored_session):
    return stored_session

  raise PopupException(_("Session expired. Please create new session and try again."))


def close_statement(self, notebook, snippet): # Individual statements cannot be closed
Expand Down Expand Up @@ -327,9 +374,9 @@ def get_jobs(self, notebook, snippet, logs):

def autocomplete(self, snippet, database=None, table=None, column=None, nested=None, operation=None):
response = {}

# As booting a new SQL session is slow and we don't send the id of the current one in /autocomplete
# we could implement this by introducing an API cache per user similarly to SqlAlchemy.

api = self.get_api()
session_key = self._get_session_key()

Expand All @@ -338,14 +385,17 @@ def autocomplete(self, snippet, database=None, table=None, column=None, nested=N
if SESSIONS.get(session_key):
self._close_unused_sessions()

session = SESSIONS[session_key] if SESSIONS.get(session_key) else self.create_session(snippet.get('type'))
if SESSIONS.get(session_key) and self._check_session(SESSIONS[session_key]):
session = SESSIONS[session_key]
else:
session = self.create_session(snippet.get('type'))

if database is None:
response['databases'] = self._show_databases(api, session)
response['databases'] = self._show_databases(api, session, snippet.get('type'))
elif table is None:
response['tables_meta'] = self._show_tables(api, session, database)
response['tables_meta'] = self._show_tables(api, session, snippet.get('type'), database)
elif column is None:
columns = self._get_columns(api, session, database, table)
columns = self._get_columns(api, session, snippet.get('type'), database, table)
response['columns'] = [col['name'] for col in columns]
response['extended_columns'] = [{
'comment': col.get('comment'),
Expand All @@ -360,52 +410,62 @@ def autocomplete(self, snippet, database=None, table=None, column=None, nested=N

def _close_unused_sessions(self):
  '''
  Closes all unused Livy sessions for a particular user to free up session resources.

  A session is closed when it is owned by the current user, is not the session stored in
  SESSIONS for this connector's session key, and its state is one of 'idle',
  'shutting_down', 'error', 'dead' or 'killed'. If the session list cannot be fetched,
  the error is logged at debug level and nothing is closed.
  '''
  api = self.get_api()
  session_key = self._get_session_key()

  # Fallback must use the same name that is read below, otherwise a failing
  # get_sessions() call ends in a NameError instead of a no-op.
  all_sessions = {}
  try:
    all_sessions = api.get_sessions()
  except Exception as e:
    message = force_unicode(str(e)).lower()
    LOG.debug(message)

  if all_sessions:
    for session in all_sessions['sessions']:
      if session['owner'] == self.user.username and session['id'] != SESSIONS[session_key]['id'] and \
          session['state'] in ('idle', 'shutting_down', 'error', 'dead', 'killed'):
        self.close_session(session)


def _check_status_and_fetch_result(self, api, session, execute_resp):
check_status = api.fetch_data(session['id'], execute_resp['id'])

while check_status['state'] in ['running', 'waiting']:
count = 0
while check_status['state'] in ['running', 'waiting'] and count < 120:
check_status = api.fetch_data(session['id'], execute_resp['id'])
count += 1
time.sleep(1)

if check_status['state'] == 'available':
return self._fetch_result(api, session, execute_resp['id'], start_over=True)


def _show_databases(self, api, session):
show_db_execute = self._execute(api, session, 'SHOW DATABASES')
def _show_databases(self, api, session, snippet_type):
show_db_execute = self._execute(api, session, snippet_type, 'SHOW DATABASES')
db_list = self._check_status_and_fetch_result(api, session, show_db_execute)

if db_list:
return [db[0] for db in db_list['data']]


def _show_tables(self, api, session, database):
use_db_execute = self._execute(api, session, 'USE %(database)s' % {'database': database})
def _show_tables(self, api, session, snippet_type, database):
use_db_execute = self._execute(api, session, snippet_type, 'USE %(database)s' % {'database': database})
use_db_resp = self._check_status_and_fetch_result(api, session, use_db_execute)

show_tables_execute = self._execute(api, session, 'SHOW TABLES')
show_tables_execute = self._execute(api, session, snippet_type, 'SHOW TABLES')
tables_list = self._check_status_and_fetch_result(api, session, show_tables_execute)

if tables_list:
return [table[1] for table in tables_list['data']]


def _get_columns(self, api, session, database, table):
use_db_execute = self._execute(api, session, 'USE %(database)s' % {'database': database})
def _get_columns(self, api, session, snippet_type, database, table):
use_db_execute = self._execute(api, session, snippet_type, 'USE %(database)s' % {'database': database})
use_db_resp = self._check_status_and_fetch_result(api, session, use_db_execute)

describe_tables_execute = self._execute(api, session, 'DESCRIBE %(table)s' % {'table': table})
describe_tables_execute = self._execute(api, session, snippet_type, 'DESCRIBE %(table)s' % {'table': table})
columns_list = self._check_status_and_fetch_result(api, session, describe_tables_execute)

if columns_list:
Expand All @@ -425,11 +485,14 @@ def get_sample_data(self, snippet, database=None, table=None, column=None, is_as
if SESSIONS.get(session_key):
self._close_unused_sessions()

session = SESSIONS[session_key] if SESSIONS.get(session_key) else self.create_session(snippet.get('type'))
if SESSIONS.get(session_key) and self._check_session(SESSIONS[session_key]):
session = SESSIONS[session_key]
else:
session = self.create_session(snippet.get('type'))

statement = self._get_select_query(database, table, column, operation)

sample_execute = self._execute(api, session, statement)
sample_execute = self._execute(api, session, snippet.get('type'), statement)
sample_result = self._check_status_and_fetch_result(api, session, sample_execute)

response = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def test_execute(self):
return_value={'id': 'test_id'}
)
)
self.api._check_session = Mock(return_value={'id': '1'})

response = self.api.execute(notebook, snippet)
assert_equal(response['id'], 'test_id')
Expand Down Expand Up @@ -197,6 +198,7 @@ def test_check_status(self):
return_value={'state': 'test_state'}
)
)
self.api._handle_session_health_check = Mock(return_value={'id': '1'})

response = self.api.check_status(notebook, snippet)
assert_equal(response['status'], 'test_state')
Expand Down

0 comments on commit 33d4f05

Please sign in to comment.