This commit introduces several improvements to the input handling process: - Implemented `improved_input()` to handle multi-line input and prevent accidental sending of pasted content. The user signals completion by pressing Ctrl+D or entering two consecutive empty lines. - Added a new command-line argument `--context` to allow users to specify the context length for the model. - Updated the help message to reflect the new argument.
425 lines
15 KiB
Python
Executable file
425 lines
15 KiB
Python
Executable file
#!/bin/python
|
|
# Chat with an intelligent assistant in your terminal
|
|
from ollama import Client
|
|
import re
|
|
import pyperclip
|
|
import sys
|
|
import argparse
|
|
import pygments
|
|
from pygments.lexers import get_lexer_by_name, guess_lexer
|
|
from pygments.formatters import TerminalFormatter
|
|
import os
|
|
import json
|
|
|
|
# --- Default configuration ---------------------------------------------------

server = 'localhost:11434'           # ollama server address (override with --host)
model = 'gemma3:12b'                 # default chat model (override with --model)
reasoning_model = 'deepseek-r1:14b'  # model used when --reasoning is passed
temp = 0.2                           # sampling temperature (override with --temp)
num_ctx = 4096                       # context window length (override with --context)

# Regexes for fenced markdown code blocks (```lang ... ```) and inline `code` spans.
pattern = r'```[a-z]*\n[\s\S]*?\n```'
line_pattern = r'`[a-z]*[\s\S]*?`'

# Chat history is persisted between runs in the user's cache directory.
# os.path.expanduser is used instead of os.environ.get('HOME') so this does not
# raise TypeError when HOME is unset (expanduser falls back to the passwd db).
history_path = os.path.join(os.path.expanduser('~'), '.cache', 'ai-assistant.history')
def save_history(data, path):
    """Persist the chat history *data* to *path* as JSON text."""
    with open(path, 'w+') as fh:
        fh.write(json.dumps(data))
def load_history(path):
    """Read and return the JSON-encoded chat history stored at *path*."""
    with open(path, 'r') as fh:
        return json.loads(fh.read())
def save_conversation(filename='conversation.md'):
    """Write the global conversation transcript to *filename*.

    A '.md' extension is appended when missing, and the name is suffixed
    with _1, _2, ... rather than overwriting an existing file.
    """
    global conversation

    if not filename.endswith('.md'):
        filename = filename + '.md'

    # Find a non-clashing filename by appending an incrementing suffix.
    stem, ext = os.path.splitext(filename)
    suffix = 1
    while os.path.exists(filename):
        filename = f"{stem}_{suffix}{ext}"
        suffix += 1

    with open(filename, 'w') as fh:
        fh.write(conversation)
def parse_commands(text):
|
|
# See if user wrote any commands here
|
|
# returns bool: True if command was executed, False if not
|
|
# importantly, the command doesn't need to execute succesfully for it to return True
|
|
tokens = text.split(' ')
|
|
match tokens[0]:
|
|
case '/save':
|
|
if len(tokens) > 1:
|
|
save_conversation(tokens[1])
|
|
else:
|
|
save_conversation()
|
|
return True
|
|
case '/clear':
|
|
global history
|
|
history = [ system_prompt ]
|
|
save_history(history, history_path)
|
|
return True
|
|
case '/clipboard':
|
|
context_query = '\n\nThe following is context provided by the user:\n'
|
|
context_query += get_string_from_clipboard() + '\n'
|
|
return text.split('/clipboard ')[1] + context_query
|
|
case '/exit':
|
|
exit()
|
|
return False
|
|
|
|
|
|
def highlight_code(language_name, code):
    """Syntax-highlight *code* for terminal output via pygments.

    language_name: lexer name to use, or None to detect the language from
    the text itself (a ```lang fence or a 'lang:' prefix line), falling
    back to guessing and finally to the bash lexer.

    For streaming use only the first (complete) line is highlighted and any
    trailing partial content is passed through unchanged; a complete fenced
    block is highlighted as a whole with the fence markers stripped.
    """
    lexer_name = language_name
    if lexer_name is None:
        # Try to discover the language from a fence or a 'lang:' marker.
        for line in code.split('\n'):
            stripped = line.strip()
            if stripped.startswith('```'):
                lexer_name = stripped.split('```')[1].strip()
                break
            elif stripped.startswith('lang:'):
                lexer_name = stripped.split(':')[1].strip()
                break

    # guess_lexer expects a string; the original passed a *list* of lines,
    # which raises a TypeError inside pygments. Also the original's else
    # branch printed "LEXER NAME " + lexer_name, a TypeError when None.
    body_text = '\n'.join(code.split('\n')[1:-1])
    if lexer_name:
        try:
            # Try to get the lexer by name.
            lexer = get_lexer_by_name(lexer_name)
        except ValueError:
            # Unknown name: guess from the code body instead.
            lexer = guess_lexer(body_text)
            if not lexer:
                lexer = get_lexer_by_name('bash')
    else:
        # No language specified or detected: guess, defaulting to bash.
        lexer = guess_lexer(body_text)
        if not lexer:
            lexer = get_lexer_by_name('bash')

    formatter = TerminalFormatter()

    # Streaming case: highlight the first complete line; pass the trailing
    # partial content through unhighlighted.
    lines = code.split('\n')
    just_code = lines[0]
    newlines = '\n'.join(lines[1:])
    # Complete fenced block: highlight the interior, drop the fences.
    if (len(lines) > 2) and ('```' in lines[0]) and ('```' in lines[-1]):
        just_code = '\n'.join(lines[1:-1])
        # The original still appended the raw tail here, duplicating the
        # block's contents in the output; emit only the highlighted text.
        newlines = ''

    highlighted_code = pygments.highlight(just_code, lexer, formatter)
    return highlighted_code + newlines
def extract_code_block(markdown_text):
    """Collect highlighted fenced code blocks from *markdown_text*.

    Falls back to inline `code` spans (with the backticks stripped) when no
    fenced block is present. Returns a list of strings.
    """
    # Fenced blocks first, each run through the terminal highlighter.
    blocks = [
        highlight_code(None, m.group(0))
        for m in re.finditer(pattern, markdown_text)
    ]

    if not blocks:
        # No fenced blocks: fall back to inline spans, minus the backticks.
        blocks = [
            m.group(0)[1:-1]
            for m in re.finditer(line_pattern, markdown_text)
        ]

    return blocks
def copy_string_to_clipboard(string):
    """Best-effort copy of *string* to the system clipboard.

    Clipboard access can fail (headless session, missing backend); such
    failures are deliberately ignored.
    """
    try:
        pyperclip.copy(string)
    except Exception:
        # Narrowed from a bare 'except', which would also have swallowed
        # KeyboardInterrupt and SystemExit.
        return
def get_string_from_clipboard():
    """Return the clipboard contents, or '' when the clipboard is unavailable."""
    try:
        result = pyperclip.paste()
    except Exception:
        # Narrowed from a bare 'except' so Ctrl+C is not swallowed here.
        result = ''
    return result
# Code blocks extracted from replies; not referenced elsewhere in this file.
code_history = []

# System prompt that seeds every conversation.
system_prompt = {"role": "system", "content": "You are a helpful, smart, kind, and efficient AI assistant. You always fulfill the user's requests accurately and concisely."}

# Full message history sent to the model on every request; persisted across
# runs via save_history()/load_history().
history = [ system_prompt ]

# Plain-text transcript of the current session, written out by save_conversation().
conversation = ""
def chat(message, stream=True):
    """Send *message* to the model and return the assistant's reply text.

    Appends the user message and the assistant reply to the global
    ``history``. When *stream* is true, the reply is printed incrementally
    with syntax highlighting applied inside fenced code blocks.
    """
    history.append({"role": "user", "content": message})
    completion = client.chat(
        model=model,
        options={"temperature":temp, "num_ctx":num_ctx},
        messages=history,
        stream=stream
    )
    result = ''
    language = ''     # current code-block language; None means "unknown/guess"
    large_chunk = []  # buffered stream fragments not yet printed
    for chunk in completion:
        if stream:
            text = chunk['message']['content']
            large_chunk.append(text)
            large_text = ''.join(large_chunk)

            # update language if entering or leaving code block
            if ('\n' in large_text) and ('```' in large_text):
                language = large_text.split('```')[1].split('\n')[0]
                if language == '':
                    # A bare ``` fence (closing, or unlabelled opening).
                    language = None
                # Print the fence line itself unhighlighted and reset buffer.
                print(large_text, end='', flush=True)
                large_chunk = []
                large_text = ''

            # Only print full lines
            if '\n' in large_text:
                output = large_text
                if language:
                    output = highlight_code(language, output)
                print(output, end='', flush=True)
                large_chunk = []
            result += text
    if not stream:
        # Non-streaming responses arrive as a single mapping.
        result = completion['message']['content']
    if stream:
        # Flush whatever partial line is still buffered.
        # NOTE(review): large_text is unbound if the stream yields no chunks
        # at all -- presumably the server always sends at least one; confirm.
        print(large_text, flush=True)
    history.append({"role": 'assistant', 'content': result})
    return result
def chat2(args, user_input, stream=True):
    """Run one user turn: handle slash-commands, select the model, send the
    message, and append both sides of the exchange to the transcript."""
    global conversation
    global model
    global reasoning_model

    command_result = parse_commands(user_input)
    if command_result:
        if type(command_result) == bool:
            # Command fully handled; nothing to send to the model.
            return ''
        elif type(command_result) == str:
            # Some commands rewrite the user's prompt instead of consuming it.
            user_input = command_result

    print('\033[91m' + 'assistant' + '\033[0m: ', end='')
    # Both branches of the original called chat() identically; only the
    # model selection differs.
    if args.reasoning:
        model = reasoning_model
    result = chat(user_input, stream)

    conversation += 'user: ' + user_input + '\n'
    conversation += 'assistant: ' + result + '\n'
    return result
def highlightify_text(full_text):
    """Return *full_text* with fenced code-block contents syntax-highlighted.

    Mirrors the streaming highlighter in chat(): fence lines toggle the
    active language and are emitted verbatim; lines inside a labelled block
    are run through highlight_code().
    """
    result = ''
    language = None
    for raw_line in full_text.split('\n'):
        text = raw_line + '\n'

        if '```' in text:
            # Entering or leaving a code block: record the fence language
            # ('' means an unlabelled fence, i.e. language unknown).
            language = text.split('```')[1].split('\n')[0]
            if language == '':
                language = None
            # Emit the fence line itself unhighlighted.
            result += text
            text = ''

        if '\n' in text:
            result += highlight_code(language, text) if language else text

    return result
def parse_args(argv=None):
    """Build and parse the command-line arguments.

    argv: optional list of argument strings; defaults to sys.argv[1:].
    (Backward-compatible addition -- existing parse_args() calls behave
    exactly as before; passing a list makes the parser testable.)
    """
    # Description fixed: the original said 'Copy and open a source file in
    # TextEdit', a copy-paste from an unrelated script.
    parser = argparse.ArgumentParser(description='Chat with an intelligent assistant in your terminal')
    # Ask a follow-up question when piping in context.
    parser.add_argument('--follow-up', '-f', nargs='?', const=True, default=False, help='Ask a follow up question when piping in context')
    # Copy the first code block of the reply to the clipboard.
    parser.add_argument('--copy', '-c', action='store_true', help='copy a codeblock if it appears')
    # Produce a single shell command from a description.
    parser.add_argument('--shell', '-s', nargs='?', const=True, default=False, help='output a shell command that does as described')
    parser.add_argument('--model', '-m', nargs='?', const=True, default=False, help='Specify model')
    parser.add_argument('--temp', '-t', nargs='?', const=True, default=False, help='Specify temperature')
    # Help text fixed: it previously said 'Specify temperature' (copy-paste bug).
    parser.add_argument('--context', nargs='?', const=True, default=False, help='Specify context length (num_ctx) for the model')
    parser.add_argument('--host', nargs='?', const=True, default=False, help='Specify host of ollama server')
    parser.add_argument('--reasoning', '-r', action='store_true', help='Use the default reasoning model deepseek-r1:14b')
    parser.add_argument('--new', '-n', action='store_true', help='Start a chat with a fresh history')
    return parser.parse_args(argv)
def reflection_mode(query, should_print=False):
    """Ask the model with a plan/reflect/answer scaffold and return only the
    text between <final answer></final answer> tags.

    Re-prompts the model when the tags are missing, with a bounded number of
    retries; if they never appear, the whole (highlighted) reply is returned
    rather than looping forever.
    """
    reflection_prompt = """
You are a helpful ai assistant that answers every question thoroughly and accurately. You always begin your response with a <planning></planning> section where you lay out your plan for answering the question. It is important that you don't make any assumptions while planning. Then you <reflect></reflect> on your plan to make sure it correctly answers the user's question. Then, if you are confident your plan in correct, you give your <draft answer>, followed by <final reflection> to make sure the answer correctly addresses the user's question. Finally, give a <final answer> with your answer to the user. If there are any ambiguous or unknown requirements, ask the user for more information as your final answer. You must always have a <final answer> no matter what, even if you are asking for clarifying questions. If you do not have the <final answer> tags, the user will not see your response. Additionally, the user can not see your planning or reflecting, they can only see what goes in the <final answer></final answer> tags, so make sure you provide any information you want to tell the user in there.


"""
    result = chat(reflection_prompt + query, stream=False)
    highlighted_result = highlightify_text(result)

    parts = highlighted_result.split('<final answer>')
    # Bug fix: the original re-asked the model here but then split the
    # *stale* highlighted_result, so the loop could never terminate when the
    # tags were missing. Re-derive the highlighted text from each new reply,
    # and bound the retries.
    attempts = 0
    while len(parts) < 2 and attempts < 3:
        result = chat('Please put your final answer in <final answer></final answer> tags.', stream=False)
        highlighted_result = highlightify_text(result)
        parts = highlighted_result.split('<final answer>')
        attempts += 1

    if len(parts) < 2:
        # Model never produced the tags: fall back to the whole reply.
        final_answer = highlighted_result
    else:
        final_answer = parts[1].split('</final answer>')[0]

    if should_print:
        print(final_answer)
    return final_answer
def set_host(host):
    """Point the module-level ``server`` setting at a different ollama host."""
    global server
    server = host
def arg_follow_up(args):
    """Return the follow-up question text.

    When --follow-up was given an inline value, that value is returned
    directly; otherwise the user is prompted interactively. sys.stdin is
    rebound to the controlling terminal only when we actually need to
    prompt -- the original reopened /dev/tty unconditionally, which fails
    in non-terminal environments even when no prompt was needed.
    """
    if args.follow_up != True:
        return args.follow_up
    # Piped-input mode leaves sys.stdin consumed; reattach to the terminal
    # so input() reads from the user.
    sys.stdin = open('/dev/tty')
    return input('> ')
def arg_shell(args):
    """Ask the model to produce one shell command from a description,
    print it, and copy it to the clipboard."""
    query = '''
Form a shell command based on the following description. Only output a working shell command. Format the command like this: `command`

Description:
'''
    # Use the inline description when one was given, otherwise prompt.
    if args.shell != True:
        query += args.shell
    else:
        query += input('> ')

    result = chat2(args, query, False)
    # Prefer the extracted code block when the reply contains one.
    blocks = extract_code_block(result)
    if len(blocks):
        result = blocks[0]
    print(result)
    copy_string_to_clipboard(result)
def handle_piped_input(args):
    """Answer a single question about content piped in on stdin."""
    piped = sys.stdin.read()
    query = ('Use the following context to answer the question. There will be no follow up questions from the user so make sure your answer is complete:\nSTART CONTEXT\n'
             + piped
             + '\nEND CONTEXT\nAfter you answer the question, reflect on your answer and determine if it answers the question correctly.')
    if args.copy:
        # Nudge the model toward a copyable code block.
        query += 'Answer the question using a codeblock for any code or shell scripts\n'
    if args.follow_up:
        query += arg_follow_up(args)
        query += '\n'

    result = chat2(args, query)
    blocks = extract_code_block(result)
    if args.copy and len(blocks):
        copy_string_to_clipboard(blocks[0])
def improved_input():
    """Read a multi-line message from stdin.

    The message is terminated by EOF (Ctrl+D) or by two consecutive empty
    lines. Blank lines are not included in the returned text (matching the
    original's output). Returns the message with trailing newlines stripped,
    or None if the user aborted (Ctrl+C) or reading failed.
    """
    lines = []
    empty_streak = 0  # consecutive blank lines seen so far
    while True:
        try:
            line = sys.stdin.readline()
            if not line:
                break  # EOF (Ctrl+D)
            if line.strip() == '':
                # Bug fix: the original 'continue'd on every blank line, so
                # its lines[-2]/lines[-1] two-empty-lines check could never
                # fire; count the streak instead.
                empty_streak += 1
                if empty_streak >= 2:
                    break
                continue
            empty_streak = 0
            lines.append(line)
        except KeyboardInterrupt:
            print("\nUser aborted input")
            return None
        except Exception as e:
            print(f"Error reading input: {e}")
            return None

    # Join the lines and strip any trailing newlines
    return ''.join(lines).rstrip('\n')
def handle_non_piped_input(args):
    """Run the interactive chat loop (stdin is a terminal).

    One-shot modes (--shell, --follow-up) run once and exit; otherwise the
    persisted history is loaded and a read/chat/save loop runs until the
    user quits.
    """
    if args.shell:
        arg_shell(args)
        exit()
    if args.follow_up:
        result = chat2(args, arg_follow_up(args))
        exit()

    # Interactive session: resume the persisted conversation.
    global history
    history = load_history(history_path)

    print("\033[91massistant\033[0m: Type your message (press Ctrl+D to send):")
    while True:
        try:
            full_input = improved_input()
            if full_input is None:
                break  # user aborted input
            if full_input.strip() == '':
                continue  # nothing to send

            chat2(args, full_input)
            # Persist after every exchange so a crash loses at most one turn.
            save_history(history, history_path)
        except (EOFError, KeyboardInterrupt):
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")
# Ollama client; created in main() once the server address is known.
client = None
def main():
    """Entry point: apply CLI overrides, connect to the server, and dispatch
    on whether input is piped in or interactive."""
    args = parse_args()
    if args.host:
        set_host(args.host)
    # Point to the local server
    global client
    client = Client(host=server)
    if args.model:
        global model
        model = args.model
    if args.temp:
        global temp
        temp = float(args.temp)
    if args.context:
        global num_ctx
        # Bug fix: context length is a token count and must be an integer;
        # the original used float(args.context).
        num_ctx = int(args.context)
    if args.new:
        global history
        history = [system_prompt]
        save_history(history, history_path)
    if not sys.stdin.isatty():
        # Content was piped in: answer once and exit.
        handle_piped_input(args)
    else:
        handle_non_piped_input(args)
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()