Add reflection option

parent c0a9377849
commit 8007943581

assistant.py | 68
@@ -136,6 +136,36 @@ def chat(message, stream=True):
     history.append({"role": 'assistant', 'content': result})
     return result
 
+def chat2(args, user_input, stream=False):
+    if args.reflect:
+        result = reflection_mode(user_input, stream)
+    else:
+        result = chat(user_input, stream)
+    return result
+
+def highlightify_text(full_text):
+    lines = full_text.split('\n')
+    result = ''
+    language = None
+    for line in lines:
+        text = line + '\n'
+
+        # update language if entering or leaving code block
+        if '```' in text:
+            language = text.split('```')[1].split('\n')[0]
+            if language == '':
+                language = None
+            result += text
+            text = ''
+
+        # Only print full lines
+        if '\n' in text:
+            output = text
+            if language:
+                output = highlight_code(language, output)
+            result += output
+    return result
+
 def parse_args():
     # Create the parser
     parser = argparse.ArgumentParser(description='Copy and open a source file in TextEdit')
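For reference, a minimal standalone sketch of the fence-tracking idea that highlightify_text() implements above. It is not part of the commit: demo_highlightify is a hypothetical name, and highlight_code() (defined elsewhere in assistant.py, not shown in this diff) is replaced by a trivial stub so the snippet runs on its own.

    def highlight_code(language, text):
        # stub standing in for the real helper in assistant.py (not part of this diff)
        return '[' + language + '] ' + text

    def demo_highlightify(full_text):
        result = ''
        language = None                      # None means we are outside a ``` fence
        for line in full_text.split('\n'):
            text = line + '\n'
            if '```' in text:                # fence line: flip the state, echo it unhighlighted
                language = text.split('```')[1].split('\n')[0] or None
                result += text
                text = ''
            if '\n' in text:                 # ordinary line: highlight only while inside a fence
                result += highlight_code(language, text) if language else text
        return result

    print(demo_highlightify("intro\n```python\nprint('hi')\n```\noutro"))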
@@ -149,11 +179,37 @@ def parse_args():
     parser.add_argument('--model', '-m', nargs='?', const=True, default=False, help='Specify model')
     # Add the --temp (-t) argument
     parser.add_argument('--temp', '-t', nargs='?', const=True, default=False, help='Specify temperature')
-    # Add the --host (-h) argument
+    # Add the --host argument
     parser.add_argument('--host', nargs='?', const=True, default=False, help='Specify host of ollama server')
+    # Add the --reflect argument
+    parser.add_argument('--reflect', action='store_true', help='Use reflection prompting style to improve output. May be slower and not work with all models.')
     # Parse the arguments
     return parser.parse_args()
 
+def reflection_mode(query, should_print=False):
+    reflection_prompt = """
+    You are a helpful ai assistant that answers every question thoroughly and accurately. You always begin your response with a <planning></planning> section where you lay out your plan for answering the question. It is important that you don't make any assumptions while planning. Then you <reflect></reflect> on your plan to make sure it correctly answers the user's question. Then, if you are confident your plan is correct, you give your <draft answer>, followed by <final reflection> to make sure the answer correctly addresses the user's question. Finally, give a <final answer> with your answer to the user. If there are any ambiguous or unknown requirements, ask the user for more information as your final answer. You must always have a <final answer> no matter what, even if you are asking for clarifying questions. If you do not have the <final answer> tags, the user will not see your response. Additionally, the user can not see your planning or reflecting, they can only see what goes in the <final answer></final answer> tags, so make sure you provide any information you want to tell the user in there.
+
+
+    """
+    result = chat(reflection_prompt + query, stream=False)
+    highlighted_result = highlightify_text(result)
+
+    # print('==DEBUG==')
+    # print(highlighted_result)
+    # print('==DEBUG==')
+
+
+    final_answer = highlighted_result.split('<final answer>')
+    while len(final_answer) < 2:
+        final_answer = chat('Please put your final answer in <final answer></final answer> tags.', stream=False)
+        final_answer = highlighted_result.split('<final answer>')
+    final_answer = final_answer[1].split('</final answer>')[0]
+
+    if should_print:
+        print(final_answer)
+    return final_answer
+
 def set_host(host):
     global server
     server = host
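One detail worth noting in the retry loop above: the follow-up chat() response is immediately overwritten by re-splitting the original highlighted_result, so a model that never emits the <final answer> tag keeps the loop going. A hedged, standalone sketch of a bounded retry that parses the follow-up reply instead; it is not part of the commit, and extract_final_answer, chat_fn, and max_retries are hypothetical names used only for illustration:

    def extract_final_answer(chat_fn, highlighted_result, max_retries=3):
        # Split on the opening tag; there are at least 2 parts once the tag is present.
        parts = highlighted_result.split('<final answer>')
        retries = 0
        while len(parts) < 2 and retries < max_retries:
            follow_up = chat_fn('Please put your final answer in <final answer></final answer> tags.')
            parts = follow_up.split('<final answer>')   # parse the new reply, not the old one
            retries += 1
        if len(parts) < 2:
            return highlighted_result                   # give up: fall back to the raw response
        return parts[1].split('</final answer>')[0]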
@@ -172,9 +228,9 @@ def arg_shell(args):
         query += args.shell
     else:
         query += input('> ')
-    result = chat(query, False)
+    result = chat2(args, query, False)
     result = blocks[0] if len(blocks := extract_code_block(result)) else result
-    print(blocks)
+    print(result)
     copy_string_to_clipboard(result)
 
 def handle_piped_input(args):
@@ -186,7 +242,7 @@ def handle_piped_input(args):
         query += arg_follow_up(args)
     query += '\n'
 
-    result = chat(query)
+    result = chat2(args, query)
     blocks = extract_code_block(result)
     if args.copy and len(blocks):
         copy_string_to_clipboard(blocks[0])
@@ -197,7 +253,7 @@ def handle_non_piped_input(args):
         exit()
     if args.follow_up:
         user_input = arg_follow_up(args)
-        result = chat(user_input)
+        result = chat2(args, user_input)
         exit()
     while True:
         try:
@@ -206,7 +262,7 @@ def handle_non_piped_input(args):
             print()
             exit()
         else:
-            chat(user_input)
+            result = chat2(args, user_input)
 
 client = None
 
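Taken together, the commit wires one new CLI switch through every chat entry point: parse_args() gains --reflect, and the call sites now go through chat2(), which dispatches to reflection_mode() when the flag is set. A small self-contained illustration of how the flag surfaces in the parsed args; this is a sketch that rebuilds only the relevant option rather than importing assistant.py:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--reflect', action='store_true',
                        help='Use reflection prompting style to improve output.')

    args = parser.parse_args(['--reflect'])
    print(args.reflect)   # True  -> chat2(args, user_input) routes through reflection_mode()

    args = parser.parse_args([])
    print(args.reflect)   # False -> chat2(args, user_input) falls back to chat()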