Compare commits
No commits in common. "c0d83584c7871bff2b6470cea400bb8468dbf76c" and "161e4dea7751a74da915a61cdb27b286badd407d" have entirely different histories.
c0d83584c7
...
161e4dea77
43
assistant.py
43
assistant.py
|
|
@ -4,9 +4,6 @@ from ollama import Client
|
||||||
import re
|
import re
|
||||||
import pyperclip
|
import pyperclip
|
||||||
import sys
|
import sys
|
||||||
import tty
|
|
||||||
import termios
|
|
||||||
import signal
|
|
||||||
import argparse
|
import argparse
|
||||||
import pygments
|
import pygments
|
||||||
from pygments.lexers import get_lexer_by_name, guess_lexer
|
from pygments.lexers import get_lexer_by_name, guess_lexer
|
||||||
|
|
@ -14,9 +11,6 @@ from pygments.formatters import TerminalFormatter
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from prompt_toolkit.key_binding import KeyBindings
|
|
||||||
from prompt_toolkit import PromptSession
|
|
||||||
|
|
||||||
server = 'localhost:11434'
|
server = 'localhost:11434'
|
||||||
model = 'gemma3:12b'
|
model = 'gemma3:12b'
|
||||||
reasoning_model='deepseek-r1:14b'
|
reasoning_model='deepseek-r1:14b'
|
||||||
|
|
@ -350,24 +344,37 @@ def handle_piped_input(args):
|
||||||
if args.copy and len(blocks):
|
if args.copy and len(blocks):
|
||||||
copy_string_to_clipboard(blocks[0])
|
copy_string_to_clipboard(blocks[0])
|
||||||
|
|
||||||
kb = KeyBindings()
|
def improved_input():
|
||||||
|
|
||||||
@kb.add('c-d')
|
|
||||||
def _(event):
|
|
||||||
event.current_buffer.validate_and_handle()
|
|
||||||
|
|
||||||
session = PromptSession(multiline=True, prompt_continuation='', key_bindings=kb)
|
|
||||||
def improved_input(prompt="> "):
|
|
||||||
"""
|
"""
|
||||||
Returns the full text (including embedded newlines) when you press Ctrl-D.
|
Handles multi-line input and prevents accidental sending of pasted content.
|
||||||
Arrow keys edit within or across lines automatically.
|
Returns the complete input text when the user indicates they're done.
|
||||||
"""
|
"""
|
||||||
|
lines = []
|
||||||
|
print("> ", end='', flush=True)
|
||||||
|
while True:
|
||||||
try:
|
try:
|
||||||
text = session.prompt(prompt)
|
line = sys.stdin.readline()
|
||||||
return text
|
if not line:
|
||||||
|
print("", flush=True) # Some visual feedback that user input has ended
|
||||||
|
break # EOF
|
||||||
|
if line.strip() == '':
|
||||||
|
continue # ignore empty lines
|
||||||
|
lines.append(line)
|
||||||
|
|
||||||
|
# Check for termination signal (Ctrl+D or two consecutive empty lines)
|
||||||
|
if len(lines) > 1 and lines[-2] == '\n' and lines[-1] == '\n':
|
||||||
|
break
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\nUser aborted input")
|
print("\nUser aborted input")
|
||||||
return None
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading input: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Join the lines and strip any trailing newlines
|
||||||
|
full_input = ''.join(lines).rstrip('\n')
|
||||||
|
return full_input
|
||||||
|
|
||||||
def handle_non_piped_input(args):
|
def handle_non_piped_input(args):
|
||||||
if args.shell:
|
if args.shell:
|
||||||
|
|
|
||||||
145
refactored.py
145
refactored.py
|
|
@ -1,20 +1,13 @@
|
||||||
#/bin/python
|
|
||||||
# AI Assistant in the terminal
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import json
|
|
||||||
from ollama import Client
|
from ollama import Client
|
||||||
import re
|
import re
|
||||||
import pyperclip
|
import pyperclip
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
import pygments
|
import pygments
|
||||||
from pygments.lexers import get_lexer_by_name, guess_lexer
|
from pygments.lexers import get_lexer_by_name, guess_lexer
|
||||||
from pygments.formatters import TerminalFormatter
|
from pygments.formatters import TerminalFormatter
|
||||||
|
import os
|
||||||
from prompt_toolkit.key_binding import KeyBindings
|
import json
|
||||||
from prompt_toolkit import PromptSession
|
|
||||||
|
|
||||||
class AIAssistant:
|
class AIAssistant:
|
||||||
def __init__(self, server="http://localhost:11434", model="gemma3:12b"):
|
def __init__(self, server="http://localhost:11434", model="gemma3:12b"):
|
||||||
|
|
@ -24,6 +17,7 @@ class AIAssistant:
|
||||||
self.temperature = 0.2
|
self.temperature = 0.2
|
||||||
self.num_ctx = 4096
|
self.num_ctx = 4096
|
||||||
self.history = [self.system_prompt()]
|
self.history = [self.system_prompt()]
|
||||||
|
self.load_history()
|
||||||
|
|
||||||
def set_host(self, host):
|
def set_host(self, host):
|
||||||
self.server = host
|
self.server = host
|
||||||
|
|
@ -45,33 +39,6 @@ class AIAssistant:
|
||||||
with open(path, 'w+') as f:
|
with open(path, 'w+') as f:
|
||||||
json.dump(self.history, f)
|
json.dump(self.history, f)
|
||||||
|
|
||||||
def determine_lexer(self, code_block):
|
|
||||||
lexer_name = None
|
|
||||||
lines = code_block.split('\n')
|
|
||||||
for line in lines:
|
|
||||||
if line.strip().startswith('```'):
|
|
||||||
lexer_part = line.strip().split('```')[1].strip()
|
|
||||||
if lexer_part:
|
|
||||||
lexer_name = lexer_part
|
|
||||||
break
|
|
||||||
elif line.strip().startswith('lang:'):
|
|
||||||
lexer_part = line.strip().split(':')[1].strip()
|
|
||||||
if lexer_part:
|
|
||||||
lexer_name = lexer_part
|
|
||||||
break
|
|
||||||
return lexer_name
|
|
||||||
|
|
||||||
def highlight_code(self, lexer_name, code):
|
|
||||||
try:
|
|
||||||
lexer = get_lexer_by_name(lexer_name) if lexer_name else guess_lexer(code)
|
|
||||||
except ValueError:
|
|
||||||
lexer = guess_lexer('\n'.join(code.split('\n')[1:-1]))
|
|
||||||
if not lexer:
|
|
||||||
lexer = get_lexer_by_name('bash')
|
|
||||||
formatter = TerminalFormatter()
|
|
||||||
highlighted_code = pygments.highlight(code, lexer, formatter)
|
|
||||||
return highlighted_code
|
|
||||||
|
|
||||||
def chat(self, message, stream=True):
|
def chat(self, message, stream=True):
|
||||||
self.history.append({"role": "user", "content": message})
|
self.history.append({"role": "user", "content": message})
|
||||||
completion = self.client.chat(
|
completion = self.client.chat(
|
||||||
|
|
@ -82,28 +49,11 @@ class AIAssistant:
|
||||||
)
|
)
|
||||||
result = ''
|
result = ''
|
||||||
large_chunk = []
|
large_chunk = []
|
||||||
language = None
|
|
||||||
|
|
||||||
for chunk in completion:
|
for chunk in completion:
|
||||||
text = chunk['message']['content']
|
text = chunk['message']['content']
|
||||||
large_chunk.append(text)
|
large_chunk.append(text)
|
||||||
large_text = ''.join(large_chunk)
|
|
||||||
|
|
||||||
if ('```' in large_text) and ('\n' in large_text.split('```')[1]):
|
|
||||||
language = large_text.split('```')[1].split('\n')[0]
|
|
||||||
large_chunk = []
|
|
||||||
if language == '':
|
|
||||||
print(large_text, end='', flush=True)
|
|
||||||
language = None
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
if language and ('\n' in large_text) and large_chunk:
|
|
||||||
output = self.highlight_code(language, large_text)
|
|
||||||
print(output, end='', flush=True)
|
|
||||||
large_chunk = []
|
|
||||||
elif not language or not large_chunk:
|
|
||||||
print(text, end='', flush=True)
|
print(text, end='', flush=True)
|
||||||
|
|
||||||
if not stream:
|
if not stream:
|
||||||
result = completion['message']['content']
|
result = completion['message']['content']
|
||||||
else:
|
else:
|
||||||
|
|
@ -132,17 +82,9 @@ class CommandLineParser:
|
||||||
def parse(self):
|
def parse(self):
|
||||||
return self.parser.parse_args()
|
return self.parser.parse_args()
|
||||||
|
|
||||||
# Keybindings
|
|
||||||
bindings = KeyBindings()
|
|
||||||
|
|
||||||
@bindings.add('c-d')
|
|
||||||
def _(event):
|
|
||||||
event.current_buffer.validate_and_handle()
|
|
||||||
|
|
||||||
class InputHandler:
|
class InputHandler:
|
||||||
def __init__(self, assistant):
|
def __init__(self, assistant):
|
||||||
self.assistant = assistant
|
self.assistant = assistant
|
||||||
self.session = PromptSession(multiline=True, prompt_continuation='', key_bindings=bindings)
|
|
||||||
|
|
||||||
def handle_input(self, args):
|
def handle_input(self, args):
|
||||||
if not sys.stdin.isatty():
|
if not sys.stdin.isatty():
|
||||||
|
|
@ -156,33 +98,14 @@ class InputHandler:
|
||||||
if args.copy:
|
if args.copy:
|
||||||
query += 'Answer the question using a codeblock for any code or shell scripts\n'
|
query += 'Answer the question using a codeblock for any code or shell scripts\n'
|
||||||
if args.follow_up:
|
if args.follow_up:
|
||||||
second_input = improved_input()
|
second_input = input('> ')
|
||||||
query += f'\n{second_input}'
|
query += f'\n{second_input}'
|
||||||
result = self.assistant.chat(query, stream=False)
|
result = self.assistant.chat(query, stream=False)
|
||||||
blocks = self.extract_code_block(result)
|
blocks = self.extract_code_block(result)
|
||||||
if args.copy and len(blocks):
|
if args.copy and len(blocks):
|
||||||
pyperclip.copy(blocks[0])
|
pyperclip.copy(blocks[0])
|
||||||
|
|
||||||
def arg_shell(args):
|
|
||||||
query = '''
|
|
||||||
Form a shell command based on the following description. Only output a working shell command. Format the command like this: `command`
|
|
||||||
|
|
||||||
Description:
|
|
||||||
'''
|
|
||||||
if args.shell != True:
|
|
||||||
query += args.shell
|
|
||||||
else:
|
|
||||||
query += self.improved_input()
|
|
||||||
result = self.assistant.chat(query, stream=False)
|
|
||||||
result = blocks[0] if len(blocks := self.extract_code_block(result)) else result
|
|
||||||
print(result)
|
|
||||||
copy_string_to_clipboard(result)
|
|
||||||
|
|
||||||
def handle_interactive_input(self, args):
|
def handle_interactive_input(self, args):
|
||||||
if args.shell:
|
|
||||||
self.arg_shell(args)
|
|
||||||
exit()
|
|
||||||
|
|
||||||
print("\033[91massistant\033[0m: Type your message (press Ctrl+D to send):")
|
print("\033[91massistant\033[0m: Type your message (press Ctrl+D to send):")
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
|
@ -191,23 +114,32 @@ Description:
|
||||||
break
|
break
|
||||||
if full_input.strip() == '':
|
if full_input.strip() == '':
|
||||||
continue
|
continue
|
||||||
|
print("\033[91massistant\033[0m: ", end='')
|
||||||
result = self.assistant.chat(full_input)
|
result = self.assistant.chat(full_input)
|
||||||
print()
|
print()
|
||||||
except (EOFError, KeyboardInterrupt):
|
except (EOFError, KeyboardInterrupt):
|
||||||
print("\nExiting...")
|
print("\nExiting...")
|
||||||
break
|
break
|
||||||
|
|
||||||
def improved_input(self, prompt="> "):
|
def improved_input(self):
|
||||||
"""
|
lines = []
|
||||||
Returns the full text (including embedded newlines) when you press Ctrl-D.
|
print("> ", end='', flush=True)
|
||||||
Arrow keys edit within or across lines automatically.
|
while True:
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
text = self.session.prompt(prompt)
|
line = sys.stdin.readline()
|
||||||
return text
|
if not line:
|
||||||
|
print(flush=True)
|
||||||
|
break
|
||||||
|
if line.strip() == '':
|
||||||
|
continue
|
||||||
|
lines.append(line)
|
||||||
|
if len(lines) > 1 and lines[-2] == '\n' and lines[-1] == '\n':
|
||||||
|
break
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\nUser aborted input")
|
print("\nUser aborted input")
|
||||||
return None
|
return None
|
||||||
|
full_input = ''.join(lines).rstrip('\n')
|
||||||
|
return full_input
|
||||||
|
|
||||||
def extract_code_block(self, text):
|
def extract_code_block(self, text):
|
||||||
pattern = r'```[a-z]*\n[\s\S]*?\n```'
|
pattern = r'```[a-z]*\n[\s\S]*?\n```'
|
||||||
|
|
@ -215,8 +147,8 @@ Description:
|
||||||
matches = re.finditer(pattern, text)
|
matches = re.finditer(pattern, text)
|
||||||
for match in matches:
|
for match in matches:
|
||||||
code_block = match.group(0)
|
code_block = match.group(0)
|
||||||
lexer_name = self.assistant.determine_lexer(code_block)
|
lexer_name = self.determine_lexer(code_block)
|
||||||
highlighted_code = self.assistant.highlight_code(lexer_name, code_block)
|
highlighted_code = self.highlight_code(lexer_name, code_block)
|
||||||
code_blocks.append(highlighted_code)
|
code_blocks.append(highlighted_code)
|
||||||
if not code_blocks:
|
if not code_blocks:
|
||||||
line_pattern = r'`[a-z]*[\s\S]*?`'
|
line_pattern = r'`[a-z]*[\s\S]*?`'
|
||||||
|
|
@ -226,6 +158,33 @@ Description:
|
||||||
code_blocks.append(code_block[1:-1])
|
code_blocks.append(code_block[1:-1])
|
||||||
return code_blocks
|
return code_blocks
|
||||||
|
|
||||||
|
def determine_lexer(self, code_block):
|
||||||
|
lexer_name = None
|
||||||
|
lines = code_block.split('\n')
|
||||||
|
for line in lines:
|
||||||
|
if line.strip().startswith('```'):
|
||||||
|
lexer_part = line.strip().split('```')[1].strip()
|
||||||
|
if lexer_part:
|
||||||
|
lexer_name = lexer_part
|
||||||
|
break
|
||||||
|
elif line.strip().startswith('lang:'):
|
||||||
|
lexer_part = line.strip().split(':')[1].strip()
|
||||||
|
if lexer_part:
|
||||||
|
lexer_name = lexer_part
|
||||||
|
break
|
||||||
|
return lexer_name
|
||||||
|
|
||||||
|
def highlight_code(self, lexer_name, code):
|
||||||
|
try:
|
||||||
|
lexer = get_lexer_by_name(lexer_name) if lexer_name else guess_lexer(code)
|
||||||
|
except ValueError:
|
||||||
|
lexer = guess_lexer('\n'.join(code.split('\n')[1:-1]))
|
||||||
|
if not lexer:
|
||||||
|
lexer = get_lexer_by_name('bash')
|
||||||
|
formatter = TerminalFormatter()
|
||||||
|
highlighted_code = pygments.highlight(code, lexer, formatter)
|
||||||
|
return highlighted_code
|
||||||
|
|
||||||
class CommandParser:
|
class CommandParser:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.commands = {
|
self.commands = {
|
||||||
|
|
@ -293,8 +252,6 @@ def main():
|
||||||
if args.new:
|
if args.new:
|
||||||
assistant.history = [assistant.system_prompt()]
|
assistant.history = [assistant.system_prompt()]
|
||||||
assistant.save_history()
|
assistant.save_history()
|
||||||
else:
|
|
||||||
assistant.load_history()
|
|
||||||
|
|
||||||
input_handler = InputHandler(assistant)
|
input_handler = InputHandler(assistant)
|
||||||
input_handler.handle_input(args)
|
input_handler.handle_input(args)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue