diff --git a/src/semantic_code_search/cli.py b/src/semantic_code_search/cli.py index 8dd44ed..27ef53d 100644 --- a/src/semantic_code_search/cli.py +++ b/src/semantic_code_search/cli.py @@ -52,6 +52,8 @@ def main(): type=str, required=False, help='Name or path of the model to use') parser.add_argument('-d', '--embed', action='store_true', default=False, required=False, help='(Re)create the embeddings index for codebase') + parser.add_argument('-en', '--encoding', type=str, default='utf-8', + required=False, help='Encoding type for codebase') parser.add_argument('-b', '--batch-size', metavar='BS', type=int, default=32, help='Batch size for embeddings generation') diff --git a/src/semantic_code_search/embed.py b/src/semantic_code_search/embed.py index d82d030..6450af4 100644 --- a/src/semantic_code_search/embed.py +++ b/src/semantic_code_search/embed.py @@ -59,18 +59,18 @@ def _extract_functions(nodes, fp, file_content, relevant_node_types): return out -def _get_repo_functions(root, supported_file_extensions, relevant_node_types): +def _get_repo_functions(root, supported_file_extensions, relevant_node_types, encoding): functions = [] print('Extracting functions from {}'.format(root)) for fp in tqdm([root + '/' + f for f in os.popen('git -C {} ls-files'.format(root)).read().split('\n')]): if not os.path.isfile(fp): continue - with open(fp, 'r') as f: + with open(fp, 'r', encoding=encoding) as f: lang = supported_file_extensions.get(fp[fp.rfind('.'):]) if lang: parser = get_parser(lang) file_content = f.read() - tree = parser.parse(bytes(file_content, 'utf8')) + tree = parser.parse(file_content.encode(encoding)) all_nodes = list(_traverse_tree(tree.root_node)) functions.extend(_extract_functions( all_nodes, fp, file_content, relevant_node_types)) @@ -81,7 +81,7 @@ def do_embed(args, model): nodes_to_extract = ['function_definition', 'method_definition', 'function_declaration', 'method_declaration'] functions = _get_repo_functions( - args.path_to_repo, _supported_file_extensions(), nodes_to_extract) + args.path_to_repo, _supported_file_extensions(), nodes_to_extract, args.encoding) if not functions: print('No supported languages found in {}. Exiting'.format(args.path_to_repo))