Add syntax highlighting to output
This commit is contained in:
parent
63c38a371f
commit
19f430e408
85
assistant.py
85
assistant.py
|
|
@ -5,6 +5,10 @@ import re
|
||||||
import pyperclip
|
import pyperclip
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
|
import pygments
|
||||||
|
from pygments.lexers import get_lexer_by_name
|
||||||
|
from pygments.formatters import TerminalFormatter
|
||||||
|
|
||||||
|
|
||||||
model = 'llama3.1:8b-instruct-q8_0'
|
model = 'llama3.1:8b-instruct-q8_0'
|
||||||
temp = 0.2
|
temp = 0.2
|
||||||
|
|
@ -12,6 +16,49 @@ temp = 0.2
|
||||||
pattern = r'```[a-z]*\n[\s\S]*?\n```'
|
pattern = r'```[a-z]*\n[\s\S]*?\n```'
|
||||||
line_pattern = r'`[a-z]*[\s\S]*?`'
|
line_pattern = r'`[a-z]*[\s\S]*?`'
|
||||||
|
|
||||||
|
def highlight_code(language_name, code):
    """Return ``code`` rendered with terminal syntax highlighting.

    Args:
        language_name: Pygments lexer alias (for example ``'python'``).
            When ``None``, the language is looked up in ``code`` itself:
            the first line starting with a ``` fence or with ``lang:``
            supplies the name.
        code: Either a fenced block (opening fence line, body lines,
            closing fence line) or text containing exactly one newline,
            which is treated as a single bare line of code.

    Returns:
        The code body (fence lines removed for fenced input) formatted by
        ``pygments.highlight`` with a ``TerminalFormatter``.
    """
    # Imported locally so this function does not rely on a module-level
    # import of guess_lexer being present.
    from pygments.lexers import guess_lexer
    from pygments.util import ClassNotFound

    lines = code.split('\n')

    # Work out which lexer name to try, falling back to markers in the text.
    lexer_name = language_name
    if lexer_name is None:
        for line in lines:
            stripped = line.strip()
            if stripped.startswith('```'):
                lexer_name = stripped.split('```')[1].strip()
                break
            if stripped.startswith('lang:'):
                lexer_name = stripped.split(':')[1].strip()
                break

    # Exactly two pieces means a single line of code without fences around
    # it; otherwise drop the first and last lines (the fence lines).
    if len(lines) == 2:
        just_code = code
    else:
        just_code = '\n'.join(lines[1:-1])

    lexer = None
    if lexer_name:
        try:
            lexer = get_lexer_by_name(lexer_name)
        except ValueError:
            # Unknown lexer name: fall through to guessing from the code.
            lexer = None

    if lexer is None:
        try:
            # guess_lexer expects the text itself, not a list of lines.
            lexer = guess_lexer(just_code)
        except ClassNotFound:
            # Guessing signals failure by raising, so default to bash here.
            lexer = get_lexer_by_name('bash')

    return pygments.highlight(just_code, lexer, TerminalFormatter())
|
||||||
|
|
||||||
def extract_code_block(markdown_text):
|
def extract_code_block(markdown_text):
|
||||||
# Use the regular expression pattern to find all matches in the markdown text
|
# Use the regular expression pattern to find all matches in the markdown text
|
||||||
matches = re.finditer(pattern, markdown_text)
|
matches = re.finditer(pattern, markdown_text)
|
||||||
|
|
@ -21,8 +68,11 @@ def extract_code_block(markdown_text):
|
||||||
for match in matches:
|
for match in matches:
|
||||||
code_block = match.group(0)
|
code_block = match.group(0)
|
||||||
|
|
||||||
# Add the code block to the list of code blocks
|
highlighted_code = highlight_code(None, code_block)
|
||||||
code_blocks.append('\n'.join(code_block.split('\n')[1:-1]))
|
|
||||||
|
# Add the highlighted code block to the list of code blocks
|
||||||
|
code_blocks.append(highlighted_code)
|
||||||
|
|
||||||
|
|
||||||
if len(code_blocks) == 0:
|
if len(code_blocks) == 0:
|
||||||
line_matches = re.finditer(line_pattern, markdown_text)
|
line_matches = re.finditer(line_pattern, markdown_text)
|
||||||
|
|
@ -55,10 +105,32 @@ def chat(message, stream=True):
|
||||||
stream=stream
|
stream=stream
|
||||||
)
|
)
|
||||||
result = ''
|
result = ''
|
||||||
|
language = ''
|
||||||
|
large_chunk = []
|
||||||
for chunk in completion:
|
for chunk in completion:
|
||||||
if stream:
|
if stream:
|
||||||
print(chunk['message']['content'], end='', flush=True)
|
text = chunk['message']['content']
|
||||||
result += chunk['message']['content']
|
large_chunk.append(text)
|
||||||
|
large_text = ''.join(large_chunk)
|
||||||
|
|
||||||
|
# Syntax highlight if possible
|
||||||
|
# check if highlighting can be done
|
||||||
|
if ('\n' in large_text) and ('```' in large_text):
|
||||||
|
language = large_text.split('```')[1].split('\n')[0]
|
||||||
|
print(large_text, end='', flush=True)
|
||||||
|
large_chunk = []
|
||||||
|
large_text = ''
|
||||||
|
if language == '':
|
||||||
|
language = None
|
||||||
|
|
||||||
|
if '\n' in large_text:
|
||||||
|
output = large_text
|
||||||
|
if language != None:
|
||||||
|
output = highlight_code(language, output)
|
||||||
|
print(output, end='', flush=True)
|
||||||
|
large_chunk = []
|
||||||
|
# print(highlighted_text, end='', flush=True)
|
||||||
|
result += text
|
||||||
if not stream:
|
if not stream:
|
||||||
result = completion['message']['content']
|
result = completion['message']['content']
|
||||||
if stream:
|
if stream:
|
||||||
|
|
@ -98,7 +170,7 @@ def arg_shell(args):
|
||||||
query += input('> ')
|
query += input('> ')
|
||||||
result = chat(query, False)
|
result = chat(query, False)
|
||||||
result = blocks[0] if len(blocks := extract_code_block(result)) else result
|
result = blocks[0] if len(blocks := extract_code_block(result)) else result
|
||||||
print(result)
|
print(blocks)
|
||||||
copy_string_to_clipboard(result)
|
copy_string_to_clipboard(result)
|
||||||
|
|
||||||
def handle_piped_input(args):
|
def handle_piped_input(args):
|
||||||
|
|
@ -121,7 +193,8 @@ def handle_non_piped_input(args):
|
||||||
exit()
|
exit()
|
||||||
if args.follow_up:
|
if args.follow_up:
|
||||||
user_input = arg_follow_up(args)
|
user_input = arg_follow_up(args)
|
||||||
chat(user_input)
|
result = chat(user_input)
|
||||||
|
code_blocks = extract_code_block(result)
|
||||||
exit()
|
exit()
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue