Welcome to my blog! Today, I’m excited to share a neat Python script that automates the generation of docstrings for your Python files. This tool is especially useful for developers who want to maintain clean and well-documented code without spending too much time writing documentation manually.
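The script uses OpenAI's Chat Completions API to write the docstrings. It scans each file line by line for class and def statements, sends each definition together with its body to the model, and asks for a concise docstring that follows standard Python conventions. The returned docstring is inserted after the definition, replacing any existing one, and the file as a whole also gets a module-level docstring. Responses are cached in a local docstring.cache file, keyed by a SHA-1 hash of the code that was sent, and a file is only overwritten once the updated version compiles. To run it you need the requests package and an OpenAI API key in the JIRA_AI_API_KEY environment variable. Let's dive into the code!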
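The script below finds definitions by checking whether a line starts with "class " or "def ", which is simple but sensitive to formatting. As an aside, Python's standard-library ast module can locate classes and functions and report whether they already have a docstring. Here is a minimal sketch of that alternative; it is not part of the script, and the helper name find_definitions and the sample source are purely illustrative.

```python
import ast


def find_definitions(source: str) -> list:
    """Return (kind, name, line, has_docstring) for every class and function in source."""
    tree = ast.parse(source)
    found = []
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef):
            kind = "class"
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            kind = "def"
        else:
            continue
        # get_docstring returns None when the body does not start with a string literal
        has_docstring = ast.get_docstring(node) is not None
        found.append((kind, node.name, node.lineno, has_docstring))
    # ast.walk visits nodes breadth-first, so sort to list them top to bottom
    return sorted(found, key=lambda item: item[2])


sample = '''
class Greeter:
    """Says hello."""

    def greet(self, name):
        return f"Hello, {name}!"


async def fetch(url):
    pass
'''

for kind, name, lineno, has_docstring in find_definitions(sample):
    status = "documented" if has_docstring else "missing docstring"
    print(f"{kind} {name} (line {lineno}): {status}")
```

Since ast does not keep comments, it helps with finding definitions rather than rewriting files; the edits themselves would still be line-based, as in the script.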
#!/usr/bin/python
import argparse
import hashlib
import json
import os
import re
import statistics
import subprocess
import sys
import tempfile
import time
import requests
system_prompt_general = """
You will be provided with Python code.
Based on this, your task is to generate a Python docstring for it.
This is the docstring that goes at the top of the file.
The file may include multiple classes or other code.
Ensure the docstring follows Python's standard docstring conventions and provides
just enough detail to make the file understandable and usable without overwhelming the reader.
Please only return the docstring, enclosed in triple quotes, without any other explanation
or additional text. The format should be:
\"\"\"
\"\"\"
Make sure to follow the format precisely and provide only the docstring content.
"""
system_prompt_class = """
You will be provided with a Python class, including its code.
Based on this, your task is to generate a Python docstring for it.
Use the class signature and body to infer the purpose of the class,
the attributes it has, and any methods it includes.
Follow these guidelines to create the docstring:
1. Summary: Provide a concise summary of the class's purpose.
Focus on what the class does and its main goal.
2. Attributes: List the attributes, their types, and a brief description
of what each one represents.
Ensure the docstring follows Python's standard docstring conventions and provides
just enough detail to make the class understandable and usable without overwhelming the reader.
Please only return the docstring, enclosed in triple quotes, without any other
explanation or additional text. The format should be:
\"\"\"
\"\"\"
Make sure to follow the format precisely and provide only the docstring content.
"""
system_prompt_def = """
You will be provided with a Python function, including its code.
Based on this, your task is to generate a Python docstring for it.
Use the function signature and body to infer the purpose of the function,
the arguments it takes, the return value, and any exceptions it may raise.
Follow these guidelines to create the docstring:
1. Summary: Provide a concise summary of the function's purpose.
Focus on what the function does and its main goal.
2. Arguments: List the parameters, their types, and a brief description
of what each one represents.
3. Return: If the function has a return value, describe the return type
and what it represents. If there's no return, OMIT THE SECTION.
4. Exceptions: If the function raises any exceptions, list them with descriptions.
If no exceptions are raised, OMIT THE SECTION.
5. Side Effects (if applicable): If the function has side effects
(e.g., modifies global state, interacts with external services), mention them.
OMIT THE SECTION if it is not clear in the code.
6. Algorithm or Key Logic (optional): If the function is complex,
provide a high-level outline of the logic or algorithm involved.
OMIT THE SECTION if it is not clear in the code.
Ensure the docstring follows Python's standard docstring conventions and provides
just enough detail to make the function understandable and usable without overwhelming the reader.
Please only return the docstring, enclosed in triple quotes, without any other
explanation or additional text. The format should be:
\"\"\"
\"\"\"
Make sure to follow the format precisely and provide only the docstring content.
"""


class OpenAICost:
    # Class-level accumulators shared by every request in this run
    cost = 0
    costs = []  # A list to store the individual request costs

    @staticmethod
    def send_cost(tokens, model):
        # Example prices in USD per 1K tokens; check OpenAI's current pricing before relying on them
        model_costs = {
            "gpt-3.5-turbo": 0.002,
            "gpt-4o-mini": 0.003,
            "gpt-4o": 0.005,
        }
        cost_per_1k_tokens = model_costs.get(model, 0)
        cost = (tokens / 1000) * cost_per_1k_tokens  # Cost is proportional to tokens
        OpenAICost.cost += cost
        OpenAICost.costs.append(cost)

    @staticmethod
    def print_cost_metrics():
        print(f"\nTotal Cost: ${OpenAICost.cost:.4f}")
        if OpenAICost.costs:
            print(
                f" Average Cost per Request: ${statistics.mean(OpenAICost.costs):.4f}"
            )
            print(f" Max Cost for a Request: ${max(OpenAICost.costs):.4f}")
            print(f" Min Cost for a Request: ${min(OpenAICost.costs):.4f}")
            if len(OpenAICost.costs) > 1:
                print(
                    f" Standard Deviation of Cost: ${statistics.stdev(OpenAICost.costs):.4f}"
                )
        else:
            print(" No costs recorded.")


class OpenAIProvider:
    def __init__(self):
        self.api_key = os.getenv("JIRA_AI_API_KEY")
        if not self.api_key:
            raise EnvironmentError("JIRA_AI_API_KEY not set in environment.")
        self.endpoint = "https://api.openai.com/v1/chat/completions"
        # Note: select_model() picks the model per request, so this setting is currently unused
        self.model = os.getenv("OPENJIRA_AI_MODEL", "gpt-4")

    def estimate_tokens(self, text: str) -> int:
        # Rough heuristic of about 3 characters per token; no real tokenizer is used
        return len(text) // 3

    def select_model(self, text: str) -> str:
        tokens = self.estimate_tokens(text)
        if tokens < 1000:  # Small inputs (under ~1,000 tokens)
            model = "gpt-3.5-turbo"
        elif tokens < 10000:  # Medium inputs (under ~10,000 tokens)
            model = "gpt-4o-mini"
        else:  # Large inputs (~10,000 tokens or more)
            model = "gpt-4o"
        # Records an estimate for the input tokens only; response tokens are not counted
        OpenAICost.send_cost(tokens, model)
        return model

    def improve_text(self, prompt: str, text: str) -> str:
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        body = {
            "model": self.select_model(text),
            "messages": [
                {"role": "system", "content": prompt},
                {"role": "user", "content": text},
            ],
            "temperature": 0.5,
        }
        response = requests.post(self.endpoint, json=body, headers=headers, timeout=300)
        if response.status_code != 200:
            raise RuntimeError(
                f"OpenAI API call failed: {response.status_code} - {response.text}"
            )
        result = response.json()["choices"][0]["message"]["content"].strip()
        # Unwrap a Markdown code fence if the model added one; the fence markers
        # are not valid Python and would make the updated file fail to compile
        if result.count("```") >= 2:
            first_pos = result.find("```")
            second_pos = result.find("```", first_pos + 3)
            inner = result[first_pos + 3 : second_pos]
            # Drop a language tag such as "python" from the opening fence line
            tag, newline, rest = inner.partition("\n")
            if newline and '"""' not in tag:
                inner = rest
            result = inner.strip()
        # If the reply contains more than one triple-quoted block, keep only the first
        if result.count('"""') > 2:
            first_pos = result.find('"""')
            second_pos = result.find('"""', first_pos + 3)
            result = result[first_pos : second_pos + 3]  # Include the closing quotes
        return result


class Docstring:
    def __init__(self, file_path, debug=False, exit=False):
        self.file_path = file_path
        self.lines = []
        self.ai = OpenAIProvider()
        self.line_index = 0
        self.multiline_index = 0
        self.cache_file = "docstring.cache"  # Relative to the current working directory
        self.debug = debug
        self.exit = exit
        with open(self.file_path, "r") as file:
            self.lines = file.readlines()
        # Ensure cache file exists, create if necessary
        if not os.path.exists(self.cache_file):
            with open(self.cache_file, "w") as cache:
                json.dump({}, cache)
        self._load_cache()
        print(" -> " + self.file_path)

    def print_debug(self, title, out):
        if not self.debug:
            return
        print("=====================================================")
        print("=====================================================")
        print(" > > " + title)
        print("=====================================================")
        out = "".join(out) if isinstance(out, list) else out
        print(out)
        print("=====================================================")
        print("=====================================================")

    def _load_cache(self):
        with open(self.cache_file, "r") as cache:
            self.cache = json.load(cache)

    def _save_cache(self):
        with open(self.cache_file, "w") as cache_file:
            json.dump(self.cache, cache_file, indent=4)

    def _generate_sha1(self, user_prompt):
        return hashlib.sha1(user_prompt.encode("utf-8")).hexdigest()

    def _get_current_timestamp(self):
        return int(time.time())

    def get_ai_docstring(self, sys_prompt, user_prompt, signature):
        sha1_hash = self._generate_sha1(user_prompt)
        # Check whether this file already has cached entries
        if self.file_path in self.cache:
            # Look for an entry whose SHA-1 matches this user prompt
            for entry in self.cache[self.file_path]:
                if entry["sha1"] == sha1_hash:
                    # Refresh the last_accessed timestamp
                    entry["last_accessed"] = self._get_current_timestamp()
                    # Return the cached docstring
                    return entry["docstring"]
        print(" Requesting AI for: " + signature)
        # Cache miss: call the AI to get the docstring
        res = self.ai.improve_text(sys_prompt, user_prompt)
        # Create a new cache entry with a last_accessed timestamp
        new_entry = {
            "sha1": sha1_hash,
            "docstring": res,
            "last_accessed": self._get_current_timestamp(),
        }
        # Add the new entry to the cache for the current file
        if self.file_path not in self.cache:
            self.cache[self.file_path] = []
        self.cache[self.file_path].append(new_entry)
        # Return the new docstring from the AI
        return res

    def remove_old_entries(self, minutes):
        current_timestamp = self._get_current_timestamp()
        threshold_timestamp = current_timestamp - (minutes * 60)
        # Remove old entries for each file in the cache
        for file_path, entries in self.cache.items():
            self.cache[file_path] = [
                entry
                for entry in entries
                if "last_accessed" in entry
                and entry["last_accessed"] >= threshold_timestamp
            ]

    def wrap_text(self, text: str, max_length=120, indent=0):
        wrapped_lines = []
        lines = text.strip().splitlines()
        # Ensure indent is an integer
        try:
            indent = int(indent)
        except (TypeError, ValueError):
            indent = 0  # Default to 0 if it's an invalid value
        spacer = " " * (indent * 4)  # Four spaces per indentation level
        for line in lines:
            line = spacer + line.strip()
            while len(line) > max_length:
                # Find the last space to split at, skipping the indentation itself:
                # splitting inside the spacer would repeat the same line forever
                split_at = line.rfind(" ", len(spacer), max_length)
                if split_at == -1:
                    split_at = max_length  # No space found, split at max_length
                wrapped_lines.append(line[:split_at].rstrip())
                line = spacer + line[split_at:].strip()
            wrapped_lines.append(line)
        return wrapped_lines

    def count_and_divide_whitespace(self, line):
        leading_whitespace = len(line) - len(line.lstrip())
        if leading_whitespace == 0:
            return 0
        return leading_whitespace // 4

    def complete(self):
        # Write the updated source to a temporary file so it can be syntax-checked
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write("".join(self.lines).encode())
        try:
            # Byte-compile the temporary file with the interpreter running this script
            result = subprocess.run(
                [sys.executable, "-m", "py_compile", temp_file_path],
                capture_output=True,
                text=True,
            )
            # Only overwrite the original file if compilation succeeded (returncode 0)
            if result.returncode == 0:
                with open(self.file_path, "w") as file:
                    print(f" Wrote: {self.file_path}")
                    file.write("".join(self.lines))
            else:
                print(f" Error compiling file: {result.stderr}")
                if self.debug:
                    name = "/tmp/" + os.path.basename(self.file_path) + ".failed"
                    with open(name, "w") as file:
                        file.write("".join(self.lines))
                    print(f" Copied here: {name}")
                if self.exit:
                    sys.exit(1)
        finally:
            # Clean up the temporary file
            os.remove(temp_file_path)
            # Drop cache entries not used in the last 14 days, then persist the cache
            self.remove_old_entries(1440 * 14)
            self._save_cache()

    def generate_class_docstring(self):
        line = self.lines[self.line_index]
        class_definition = line
        output = re.sub("\\s+", " ", class_definition.rstrip().replace("\n", " "))
        print(" -> " + output)
        prompt_class_code = [class_definition]
        # Only top-level classes are documented; indented (nested) classes are skipped
        if self.count_and_divide_whitespace(class_definition) > 0:
            self.line_index = self.line_index + 1
            return None
        t = self.line_index + 1
        # Collect all lines that belong to the class, including "pass" or single-line classes
        while (
            t < len(self.lines)
            and not self.lines[t].startswith("def")
            and not self.lines[t].startswith("class")
        ):
            prompt_class_code.append(self.lines[t].rstrip())
            t += 1
        class_docstring = self.get_ai_docstring(
            system_prompt_class, "\n".join(prompt_class_code), output
        )
        class_docstring = self.wrap_text(class_docstring, max_length=120, indent=1)
        # One newline per line, with no trailing whitespace on blank lines
        class_docstring = [line.rstrip() + "\n" for line in class_docstring]
        # Check for an existing docstring on the line right after the class definition
        docstring_start_index = None
        docstring_end_index = None
        next_index = self.line_index + 1
        if next_index < len(self.lines) and self.lines[next_index].strip().startswith('"""'):
            docstring_start_index = next_index
            first = self.lines[next_index].strip()
            if len(first) >= 6 and first.endswith('"""'):
                # One-line docstring such as """Summary."""
                docstring_end_index = next_index
            else:
                # Multi-line docstring: find the line holding the closing quotes
                for i in range(next_index + 1, len(self.lines)):
                    if '"""' in self.lines[i]:
                        docstring_end_index = i
                        break
        # If a docstring exists, remove it
        if docstring_start_index is not None and docstring_end_index is not None:
            self.lines = (
                self.lines[:docstring_start_index]
                + self.lines[docstring_end_index + 1 :]
            )
        # Insert the new docstring after the class definition
        self.lines = (
            self.lines[: self.line_index + 1]
            + class_docstring
            + self.lines[self.line_index + 1 :]
        )
        # self.print_debug("class docstring: " + class_definition.strip(), self.lines)
        return True

    def generate_function_docstring(self):
        line = self.lines[self.line_index]
        multiline_line = ""
        # Take the multi-line path unless the signature is complete on this line
        # (ends with "):" and has no "-> ..." return annotation)
        if not (
            line.strip().endswith("):")
            and not re.search(r"\)\s*->\s*(.*)\s*:.*", line.strip())
        ):
            self.multiline_index = 0
            # Multi-line def signature: collect lines until "):" or "-> ...:" closes it
            while self.line_index < len(self.lines):
                multiline_line += self.lines[self.line_index]
                if re.match(
                    r".*\):$", self.lines[self.line_index].strip()
                ) or re.search(
                    r".*\)\s*->\s*(.*)\s*:.*", self.lines[self.line_index].strip()
                ):
                    break
                self.line_index += 1
                self.multiline_index += 1
        def_definition = line if multiline_line == "" else multiline_line
        output = re.sub("\\s+", " ", def_definition.rstrip().replace("\n", " "))
        print(" -> " + output)
        prompt_def_code = (
            multiline_line.split("\n") if multiline_line != "" else [def_definition]
        )
        indent_line = self.count_and_divide_whitespace(
            def_definition if multiline_line == "" else multiline_line.splitlines()[0]
        )
        spacer_line = "" if indent_line == 0 else " " * (indent_line * 4)
        spacer_line_minus = "" if indent_line < 2 else " " * ((indent_line - 1) * 4)
        spacer_line_plus = "" if indent_line == 0 else " " * ((indent_line + 1) * 4)
        t = self.line_index + 1
        # Collect all lines that belong to the function
        while t < len(self.lines):
            starts_with_def = self.lines[t].strip().startswith("def")
            # same indent
            if starts_with_def and self.lines[t].startswith(spacer_line):
                break
            # outside
            if starts_with_def and self.lines[t].startswith(spacer_line_minus):
                break
            # nested (never reached: a deeper def already matches the same-indent check)
            if starts_with_def and self.lines[t].startswith(spacer_line_plus):
                pass
            if self.lines[t].rstrip() != def_definition.strip():
                prompt_def_code.append(self.lines[t].rstrip())
            t += 1
        # Now that we have the full function signature, generate the docstring
        indent = indent_line + 1
        def_docstring = self.get_ai_docstring(
            system_prompt_def, "\n".join(prompt_def_code), output
        )
        def_docstring = self.wrap_text(def_docstring, max_length=120, indent=indent)
        # Drop the first line if the model echoed the (single-line) signature back
        if def_docstring and def_definition.strip() == def_docstring[0].strip():
            def_docstring = def_docstring[1:]
        # One newline per line, with no trailing whitespace on blank lines
        def_docstring = [line.rstrip() + "\n" for line in def_docstring]
        # Replace an existing docstring (one-line or multi-line) or insert a new one
        next_index = self.line_index + 1
        if next_index < len(self.lines) and self.lines[next_index].strip().startswith('"""'):
            stripped_line = self.lines[next_index].strip()
            if re.match(r'"""[\s\S]+?"""', stripped_line):
                # One-line docstring: replace it (we always write a multi-line docstring)
                self.lines = (
                    self.lines[:next_index]
                    + def_docstring
                    + self.lines[next_index + 1 :]
                )
            else:
                # Multi-line docstring: replace everything up to the closing '"""' line
                end_index = next_index + 1
                while end_index < len(self.lines) and not self.lines[
                    end_index
                ].strip().startswith('"""'):
                    end_index += 1
                if (
                    end_index < len(self.lines)
                    and self.lines[end_index].strip() == '"""'
                ):
                    self.lines = (
                        self.lines[:next_index]
                        + def_docstring
                        + self.lines[end_index + 1 :]
                    )
        else:
            # If no docstring exists, simply insert the generated docstring
            self.lines = (
                self.lines[:next_index] + def_docstring + self.lines[next_index:]
            )
        # self.print_debug("def docstring: " + def_definition.strip(), self.lines)
        self.line_index = self.line_index + len(def_docstring)
        return True

    def generate_file_docstring(self):
        # Check if we should add a file-level docstring
        # if not self.should_add_file_docstring():
        #     return 0  # Skip generating file-level docstring if not needed
        shebang = ""
        # Make sure the first line is a shebang (e.g. "#!..."); add one if it is missing
        if self.lines and not self.lines[0].startswith("#!"):
            self.lines = ["#!/usr/bin/env python\n"] + self.lines
            shebang = "#!/usr/bin/env python\n"
        else:
            shebang = self.lines[0]
        # Check for an existing file-level docstring right after the shebang.
        # Only docstrings delimited by """ are recognized.
        docstring_start_index = None
        docstring_end_index = None
        if len(self.lines) > 1 and self.lines[1].strip().startswith('"""'):
            docstring_start_index = 1  # The docstring starts on line 2
            first = self.lines[1].strip()
            if len(first) >= 6 and first.endswith('"""'):
                # One-line docstring such as """Summary."""
                docstring_end_index = 1
            else:
                for i in range(2, len(self.lines)):
                    if '"""' in self.lines[i]:
                        docstring_end_index = i  # Line holding the closing quotes
                        break
        # Generate new file-level docstring
        general_description = self.get_ai_docstring(
            system_prompt_general, "".join(self.lines), self.file_path
        )
        general_description = self.wrap_text(
            general_description, max_length=120, indent=0
        )
        docstring = [line + "\n" for line in general_description]
        # self.print_debug("docstring_end_index", str(docstring_end_index))
        # self.print_debug("self.lines[:docstring_start_index]", self.lines[:docstring_start_index])
        # self.print_debug("docstring", docstring)
        # self.print_debug("self.lines[docstring_end_index + 1 :]", self.lines[docstring_end_index + 1 :])
        # If a docstring exists, replace it with the new one
        if docstring_start_index is not None and docstring_end_index is not None:
            self.lines = (
                self.lines[:docstring_start_index]
                + docstring
                + self.lines[docstring_end_index + 1 :]
            )
        else:
            # Insert the generated docstring directly after the shebang (no extra newline)
            self.lines = [shebang] + docstring + self.lines[1:]
        # self.print_debug("file docstring: " + self.file_path, self.lines)
        return len(docstring)

    def generate_docstrings(self):
        if len(self.lines) == 0:
            return
        self.line_index = self.generate_file_docstring()
        while self.line_index < len(self.lines):
            line = self.lines[self.line_index]
            # For classes, generate class docstring
            if line.strip().startswith("class "):
                if not self.generate_class_docstring():
                    continue  # Skip to the next line
            # For functions, generate function docstring
            elif line.strip().startswith("def "):
                if not self.generate_function_docstring():
                    continue  # Skip to the next line
            self.line_index = self.line_index + 1
        self.complete()


def process_file(file_path, debug=False, exit=False):
    """Process a single file by generating docstrings."""
    Docstring(file_path, debug=debug, exit=exit).generate_docstrings()


def process_directory(directory_path, recursive=False, debug=False, exit=False):
    """Process all Python files in the directory with progress tracking."""
    if recursive:
        # Walk the whole directory tree
        python_files = [
            os.path.join(root, file)
            for root, dirs, files in os.walk(directory_path)
            for file in files
            if file.endswith(".py")
        ]
    else:
        # Only the files directly inside directory_path
        python_files = [
            os.path.join(directory_path, file)
            for file in os.listdir(directory_path)
            if file.endswith(".py")
            and os.path.isfile(os.path.join(directory_path, file))
        ]
    total_files = len(python_files)  # Total Python file count
    print(f"Processing {total_files} Python files...")
    for processed_files, file_path in enumerate(python_files, start=1):
        process_file(file_path, debug=debug, exit=exit)
        print(f"\nProcessed file: {processed_files}/{total_files}")


def main():
    # Set up the argument parser
    parser = argparse.ArgumentParser(
        description="Generate file-, class-, and function-level docstrings for Python files."
    )
    parser.add_argument("path", help="Path to a Python file or directory.")
    parser.add_argument(
        "-r",
        "--recursive",
        action="store_true",
        help="Recursively process all Python files in the directory.",
    )
    parser.add_argument(
        "-d",
        "--debug",
        action="store_true",
        help="Copy updated files that fail to compile to /tmp/",
    )
    parser.add_argument(
        "-e",
        "--exit",
        action="store_true",
        help="Exit on the first file that fails to compile",
    )
    # Parse the arguments
    args = parser.parse_args()
    # Check if the path is a file or directory
    if os.path.isfile(args.path):
        # If it's a file, process it directly
        process_file(args.path, debug=args.debug, exit=args.exit)
    elif os.path.isdir(args.path):
        # If it's a directory, process all Python files
        process_directory(
            args.path, recursive=args.recursive, debug=args.debug, exit=args.exit
        )
    else:
        print(f"Error: {args.path} is neither a valid file nor a directory.")
    OpenAICost.print_cost_metrics()


if __name__ == "__main__":
    main()
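One detail that is easy to miss in the listing is how the cache works. get_ai_docstring looks entries up by file path and by the SHA-1 hash of the exact code sent to the model, refreshes last_accessed on every hit, and remove_old_entries drops anything unused for 14 days (1440 * 14 minutes) before the cache is saved. The sketch below mirrors that layout in memory so it can be tried without an API key; the PromptCache class and the explicit now argument are illustrative additions, not part of the script.

```python
import hashlib
from typing import Optional


class PromptCache:
    """In-memory sketch of the docstring.cache layout used by the script.

    Maps a file path to a list of {"sha1", "docstring", "last_accessed"} entries.
    """

    def __init__(self) -> None:
        self.data: dict = {}

    @staticmethod
    def key(prompt: str) -> str:
        # Same hashing as Docstring._generate_sha1 in the script
        return hashlib.sha1(prompt.encode("utf-8")).hexdigest()

    def get(self, path: str, prompt: str, now: int) -> Optional[str]:
        digest = self.key(prompt)
        for entry in self.data.get(path, []):
            if entry["sha1"] == digest:
                entry["last_accessed"] = now  # Every hit refreshes the timestamp
                return entry["docstring"]
        return None  # Cache miss: this is where the script calls the API

    def put(self, path: str, prompt: str, docstring: str, now: int) -> None:
        self.data.setdefault(path, []).append(
            {"sha1": self.key(prompt), "docstring": docstring, "last_accessed": now}
        )

    def prune(self, max_age_minutes: int, now: int) -> None:
        # Keep only entries used within the last max_age_minutes
        threshold = now - max_age_minutes * 60
        for path, entries in self.data.items():
            self.data[path] = [
                e for e in entries
                if "last_accessed" in e and e["last_accessed"] >= threshold
            ]


cache = PromptCache()
code = "def add(a, b):\n    return a + b\n"
cache.put("math_utils.py", code, '"""Add two numbers."""', now=1_000)
print(cache.get("math_utils.py", code, now=2_000))         # prints """Add two numbers."""
print(cache.get("math_utils.py", code + "\n", now=2_000))  # prints None: any change alters the hash
fourteen_days = 1440 * 14  # minutes, as passed to remove_old_entries
cache.prune(fourteen_days, now=2_000 + fourteen_days * 60 + 1)
print(cache.data)                                           # prints {'math_utils.py': []}
```

Because the key is computed from the whole snippet, including any docstring it already contains, a rerun over a file the script has just updated will usually miss the cache and send new requests.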
#!/usr/bin/python
import argparse
import hashlib
import json
import os
import re
import statistics
import subprocess
import sys
import tempfile
import time
import requests
system_prompt_general = """
You will be provided with a Python code.
Based on this, your task is to generate a Python docstring for it.
This is the docstring that goes at the top of the file.
The file may include multiple classes or other code.
Ensure the docstring follows Python's standard docstring conventions and provides
just enough detail to make the file understandable and usable without overwhelming the reader.
Please only return the docstring, enclosed in triple quotes, without any other explanation
or additional text. The format should be:
\"\"\"
\"\"\"
Make sure to follow the format precisely and provide only the docstring content.
"""
system_prompt_class = """
You will be provided with a Python class, including its code.
Based on this, your task is to generate a Python docstring for it.
Use the class signature and body to infer the purpose of the class,
the attributes it has, and any methods it includes.
Follow these guidelines to create the docstring:
1. Summary: Provide a concise summary of the class's purpose.
Focus on what the class does and its main goal.
2. Attributes: List the attributes, their types, and a brief description
of what each one represents.
Ensure the docstring follows Python's standard docstring conventions and provides
just enough detail to make the class understandable and usable without overwhelming the reader.
Please only return the docstring, enclosed in triple quotes, without any other
explanation or additional text. The format should be:
\"\"\"
\"\"\"
Make sure to follow the format precisely and provide only the docstring content.
"""
system_prompt_def = """
You will be provided with a Python function, including its code.
Based on this, your task is to generate a Python docstring for it.
Use the function signature and body to infer the purpose of the function,
the arguments it takes, the return value, and any exceptions it may raise.
Follow these guidelines to create the docstring:
1. Summary: Provide a concise summary of the function's purpose.
Focus on what the function does and its main goal.
2. Arguments: List the parameters, their types, and a brief description
of what each one represents.
3. Return: If the function has a return value, describe the return type
and what it represents. If there's no return, OMIT THE SECTION.
4. Exceptions: If the function raises any exceptions, list them with descriptions.
If no exceptions are raised, OMIT THE SECTION.
5. Side Effects (if applicable): If the function has side effects
(e.g., modifies global state, interacts with external services), mention them.
OMIT THE SECTION if it is not clear in the code.
6. Algorithm or Key Logic (optional): If the function is complex,
provide a high-level outline of the logic or algorithm involved.
OMIT THE SECTION if it is not clear in the code.
Ensure the docstring follows Python's standard docstring conventions and provides
just enough detail to make the function understandable and usable without overwhelming the reader.
Please only return the docstring, enclosed in triple quotes, without any other
explanation or additional text. The format should be:
\"\"\"
\"\"\"
Make sure to follow the format precisely and provide only the docstring content.
"""
class OpenAICost:
# Static member to track the total cost
cost = 0
costs = [] # A list to store the individual request costs
@staticmethod
def send_cost(tokens, model):
# The cost calculation can vary based on the model and tokens
model_costs = {
"gpt-3.5-turbo": 0.002, # Example cost per 1k tokens
"gpt-4o-mini": 0.003, # Example cost per 1k tokens
"gpt-4o": 0.005, # Example cost per 1k tokens
}
cost_per_token = model_costs.get(model, 0)
cost = (tokens / 1000) * cost_per_token # Cost is proportional to tokens
OpenAICost.cost += cost
OpenAICost.costs.append(cost)
@staticmethod
def print_cost_metrics():
print(f"\nTotal Cost: ${OpenAICost.cost:.4f}")
if OpenAICost.costs:
print(
f" Average Cost per Request: ${statistics.mean(OpenAICost.costs):.4f}"
)
print(f" Max Cost for a Request: ${max(OpenAICost.costs):.4f}")
print(f" Min Cost for a Request: ${min(OpenAICost.costs):.4f}")
if len(OpenAICost.costs) > 1:
print(
f" Standard Deviation of Cost: ${statistics.stdev(OpenAICost.costs):.4f}"
)
else:
print(" No costs recorded.")
class OpenAIProvider:
def __init__(self):
self.api_key = os.getenv("JIRA_AI_API_KEY")
if not self.api_key:
raise EnvironmentError("JIRA_AI_API_KEY not set in environment.")
self.endpoint = "https://api.openai.com/v1/chat/completions"
self.model = os.getenv("OPENJIRA_AI_MODEL", "gpt-4")
def estimate_tokens(self, text: str) -> int:
tokens = len(text) // 3 # Using the 3 bytes per token estimation
return tokens
def select_model(self, input):
tokens = self.estimate_tokens(input)
if tokens < 1000: # For small files (under ~1000 tokens)
model = "gpt-3.5-turbo"
elif tokens < 10000: # For medium files (under ~10000 tokens)
model = "gpt-4o-mini"
else: # For large files (over ~10000 tokens)
model = "gpt-4o"
OpenAICost.send_cost(tokens, model)
return model
def improve_text(self, prompt: str, text: str) -> str:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
body = {
"model": self.select_model(text),
"messages": [
{"role": "system", "content": prompt},
{"role": "user", "content": text},
],
"temperature": 0.5,
}
response = requests.post(self.endpoint, json=body, headers=headers, timeout=300)
if response.status_code == 200:
res = response.json()["choices"][0]["message"]["content"].strip()
result = res
# Count the occurrences of '"""'
occurrences = result.count('"""')
occurrences_backtick = result.count("```")
# Check if there are more than two occurrences
if occurrences > 2:
# Find the positions of the first and second occurrences
first_pos = result.find('"""')
second_pos = result.find('"""', first_pos + 1)
# Get everything from the first '"""' to the second '"""', inclusive
result = result[first_pos : second_pos + 3] # Include the second '"""'
if occurrences_backtick > 2:
# Find the positions of the first and second occurrences
first_pos = result.find("```")
second_pos = result.find("```", first_pos + 1)
# Get everything from the first '"""' to the second '"""', inclusive
result = result[first_pos : second_pos + 3] # Include the second '"""'
return result
raise Exception(
f"OpenAI API call failed: {response.status_code} - {response.text}"
)
class Docstring:
def __init__(self, file_path, debug=False, exit=False):
self.file_path = file_path
self.lines = []
self.ai = OpenAIProvider()
self.line_index = 0
self.multiline_index = 0
self.cache_file = "docstring.cache"
self.debug = debug
self.exit = exit
with open(self.file_path, "r") as file:
self.lines = file.readlines()
# Ensure cache file exists, create if necessary
if not os.path.exists(self.cache_file):
with open(self.cache_file, "w") as cache:
json.dump({}, cache)
self._load_cache()
print(" -> " + self.file_path)
def print_debug(self, title, out):
if not self.debug:
return
print("=====================================================")
print("=====================================================")
print(" > > " + title)
print("=====================================================")
out = "".join(out) if isinstance(out, list) else out
print(out)
print("=====================================================")
print("=====================================================")
def _load_cache(self):
with open(self.cache_file, "r") as cache:
self.cache = json.load(cache)
def _save_cache(self):
with open(self.cache_file, "w") as cache_file:
json.dump(self.cache, cache_file, indent=4)
def _generate_sha1(self, user_prompt):
return hashlib.sha1(user_prompt.encode("utf-8")).hexdigest()
def _get_current_timestamp(self):
return int(time.time())
def get_ai_docstring(self, sys_prompt, user_prompt, signiture):
sha1_hash = self._generate_sha1(user_prompt)
# Check if file is in cache
if self.file_path in self.cache:
# Check if the user prompt's SHA1 is in the self.cache for this file
for entry in self.cache[self.file_path]:
if entry["sha1"] == sha1_hash:
# Update last_accessed timestamp
entry["last_accessed"] = self._get_current_timestamp()
# Return cached docstring if found
return entry["docstring"]
print(" Requesting AI for: " + signiture)
# If no self.cache hit, call the AI and get the docstring
res = self.ai.improve_text(sys_prompt, user_prompt)
# Create a new self.cache entry with last_accessed timestamp
new_entry = {
"sha1": sha1_hash,
"docstring": res,
"last_accessed": self._get_current_timestamp(),
}
# Add new entry to the self.cache for the current file
if self.file_path not in self.cache:
self.cache[self.file_path] = []
self.cache[self.file_path].append(new_entry)
# Return the new docstring from AI
return res
def remove_old_entries(self, minutes):
current_timestamp = self._get_current_timestamp()
threshold_timestamp = current_timestamp - (minutes * 60)
# Remove old entries for each file in the self.cache
for file_path, entries in self.cache.items():
self.cache[file_path] = [
entry
for entry in entries
if "last_accessed" in entry
and entry["last_accessed"] >= threshold_timestamp
]
def wrap_text(self, text: str, max_length=120, indent=0):
wrapped_lines = []
lines = text.strip().splitlines()
# Ensure indent is an integer
try:
indent = int(indent)
except ValueError:
indent = 0 # Default to 0 if it’s an invalid value
spacer = ""
if indent == 0:
spacer = ""
else:
spacer = " " * (indent * 4)
for line in lines:
line = spacer + line.strip()
while len(line) > max_length:
# Find last space to split at
split_at = line.rfind(" ", 0, max_length)
if split_at == -1:
split_at = max_length # no space found, split at max_length
wrapped_lines.append(line[:split_at].rstrip())
line = spacer + line[split_at:].strip()
wrapped_lines.append(line)
return wrapped_lines
def count_and_divide_whitespace(self, line):
leading_whitespace = len(line) - len(line.lstrip())
if leading_whitespace == 0:
return 0
return leading_whitespace // 4
    def complete(self):
        source = "".join(self.lines)
        # Write the updated source to a temporary .py file so it can be syntax-checked
        with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(source.encode())
        compiled = False
        try:
            # Compile the temporary file with the same interpreter that runs this script
            result = subprocess.run(
                [sys.executable, "-m", "py_compile", temp_file_path],
                capture_output=True,
                text=True,
            )
            compiled = result.returncode == 0
            # Only overwrite the original file if the new source compiles
            if compiled:
                with open(self.file_path, "w") as file:
                    file.write(source)
                print(f" Wrote: {self.file_path}")
            else:
                print(f" Error compiling file: {result.stderr}")
                if self.debug:
                    # Keep the rejected version around for inspection
                    name = os.path.join("/tmp", os.path.basename(self.file_path) + ".failed")
                    with open(name, "w") as file:
                        file.write(source)
                    print(f" Copied here: {name}")
        finally:
            # Clean up the temporary file
            os.remove(temp_file_path)
        # Prune entries unused for 14 days (1440 minutes per day), then persist the cache
        self.remove_old_entries(1440 * 14)
        self._save_cache()
        # Exit only after saving, so docstrings already paid for are not lost
        if not compiled and self.exit:
            sys.exit(1)
    def generate_class_docstring(self):
        class_definition = self.lines[self.line_index]
        output = re.sub(r"\s+", " ", class_definition.rstrip())
        print(" -> " + output)
        # Only top-level classes are documented; nested classes are skipped
        if self.count_and_divide_whitespace(class_definition) > 0:
            self.line_index += 1
            return None
        prompt_class_code = [class_definition.rstrip()]
        t = self.line_index + 1
        # Collect every line up to the next top-level def or class ("pass" and one-line classes included)
        while (
            t < len(self.lines)
            and not self.lines[t].startswith("def")
            and not self.lines[t].startswith("class")
        ):
            prompt_class_code.append(self.lines[t].rstrip())
            t += 1
        class_docstring = self.get_ai_docstring(
            system_prompt_class, "\n".join(prompt_class_code), output
        )
        # Wrap and indent one level, then restore the newline on every line
        class_docstring = [
            line + "\n"
            for line in self.wrap_text(class_docstring, max_length=120, indent=1)
        ]
        if not class_docstring:
            return True  # Nothing to insert; the caller moves to the next line
        # Look for an existing docstring on the line right after the class definition
        docstring_start_index = None
        docstring_end_index = None
        next_index = self.line_index + 1
        if next_index < len(self.lines) and self.lines[next_index].strip().startswith('"""'):
            docstring_start_index = next_index
            first_line = self.lines[next_index].strip()
            if len(first_line) >= 6 and first_line.endswith('"""'):
                # One-line docstring: it opens and closes on the same line
                docstring_end_index = next_index
            else:
                # Multi-line docstring: it ends at the next line containing """
                for i in range(next_index + 1, len(self.lines)):
                    if '"""' in self.lines[i]:
                        docstring_end_index = i
                        break
        # Remove the existing docstring, if any
        if docstring_start_index is not None and docstring_end_index is not None:
            self.lines = (
                self.lines[:docstring_start_index]
                + self.lines[docstring_end_index + 1 :]
            )
        # Insert the new docstring after the class definition
        self.lines = self.lines[:next_index] + class_docstring + self.lines[next_index:]
        # Skip past the inserted docstring so its text is not scanned for def/class
        self.line_index += len(class_docstring)
        return True
    def generate_function_docstring(self):
        line = self.lines[self.line_index]
        multiline_line = ""
        # A signature fits on one line if it ends with "):" or has a return annotation
        single_line_signature = line.strip().endswith("):") or re.search(
            r"\)\s*->\s*(.*)\s*:.*", line.strip()
        )
        if not single_line_signature:
            # Multi-line def signature: accumulate lines until the one that closes it
            while self.line_index < len(self.lines):
                multiline_line += self.lines[self.line_index]
                stripped = self.lines[self.line_index].strip()
                if re.match(r".*\):$", stripped) or re.search(
                    r"\)\s*->\s*(.*)\s*:.*", stripped
                ):
                    break
                self.line_index += 1
        def_definition = line if multiline_line == "" else multiline_line
        output = re.sub(r"\s+", " ", def_definition.rstrip())
        print(" -> " + output)
        prompt_def_code = def_definition.rstrip().splitlines()
        # Indentation level of the def line (the first line of a multi-line signature)
        indent_line = self.count_and_divide_whitespace(def_definition.splitlines()[0])
        t = self.line_index + 1
        # Collect the body: it ends at the first non-blank line not indented deeper than the def
        while t < len(self.lines):
            body_line = self.lines[t]
            if (
                body_line.strip()
                and self.count_and_divide_whitespace(body_line) <= indent_line
            ):
                break
            prompt_def_code.append(body_line.rstrip())
            t += 1
        def_docstring = self.get_ai_docstring(
            system_prompt_def, "\n".join(prompt_def_code), output
        )
        # Docstring lines sit one indentation level deeper than the def
        def_docstring = self.wrap_text(
            def_docstring, max_length=120, indent=indent_line + 1
        )
        # The model sometimes echoes the signature; drop anything before the opening quotes
        for i, doc_line in enumerate(def_docstring):
            if doc_line.strip().startswith('"""'):
                def_docstring = def_docstring[i:]
                break
        def_docstring = [doc_line + "\n" for doc_line in def_docstring]
        if not def_docstring:
            return True  # Nothing to insert; the caller moves to the next line
        # Look for an existing docstring right after the signature
        next_index = self.line_index + 1
        end_index = None
        if next_index < len(self.lines) and self.lines[next_index].strip().startswith('"""'):
            first_line = self.lines[next_index].strip()
            if len(first_line) >= 6 and first_line.endswith('"""'):
                # One-line docstring: it opens and closes on the same line
                end_index = next_index
            else:
                # Multi-line docstring: it ends at the next line containing """
                for i in range(next_index + 1, len(self.lines)):
                    if '"""' in self.lines[i]:
                        end_index = i
                        break
        if end_index is not None:
            # Replace the existing docstring (always with a multi-line one)
            self.lines = (
                self.lines[:next_index] + def_docstring + self.lines[end_index + 1 :]
            )
        else:
            # No docstring yet: insert the generated one after the signature
            self.lines = (
                self.lines[:next_index] + def_docstring + self.lines[next_index:]
            )
        # Skip past the inserted docstring
        self.line_index += len(def_docstring)
        return True
    def generate_file_docstring(self):
        # Make sure the file starts with a shebang; add one if it is missing
        if not self.lines or not self.lines[0].startswith("#!"):
            self.lines = ["#!/usr/bin/env python\n"] + self.lines
        shebang = self.lines[0]
        # An existing file-level docstring is expected on the line after the shebang,
        # opened with """
        docstring_start_index = None
        docstring_end_index = None
        if len(self.lines) > 1 and self.lines[1].strip().startswith('"""'):
            docstring_start_index = 1
            first_line = self.lines[1].strip()
            if len(first_line) >= 6 and first_line.endswith('"""'):
                # One-line docstring: it opens and closes on the same line
                docstring_end_index = 1
            else:
                # Multi-line docstring: it ends at the next line containing """
                for i in range(2, len(self.lines)):
                    if '"""' in self.lines[i]:
                        docstring_end_index = i
                        break
        # Generate the new file-level docstring from the whole file
        general_description = self.get_ai_docstring(
            system_prompt_general, "".join(self.lines), self.file_path
        )
        docstring = [
            line + "\n"
            for line in self.wrap_text(general_description, max_length=120, indent=0)
        ]
        if docstring_start_index is not None and docstring_end_index is not None:
            # Replace the existing docstring with the new one
            self.lines = (
                self.lines[:docstring_start_index]
                + docstring
                + self.lines[docstring_end_index + 1 :]
            )
        else:
            # Insert the generated docstring directly after the shebang (no extra newline)
            self.lines = [shebang] + docstring + self.lines[1:]
        return len(docstring)
    def generate_docstrings(self):
        if len(self.lines) == 0:
            return
        self.line_index = self.generate_file_docstring()
        while self.line_index < len(self.lines):
            line = self.lines[self.line_index]
            if line.strip().startswith("class "):
                # A falsy return means the method already advanced line_index
                if not self.generate_class_docstring():
                    continue
            elif line.strip().startswith("def "):
                if not self.generate_function_docstring():
                    continue
            self.line_index = self.line_index + 1
        self.complete()
def process_file(file_path, debug=False, exit=False):
    """Process a single file by generating docstrings."""
    Docstring(file_path, debug=debug, exit=exit).generate_docstrings()
def process_directory(directory_path, recursive=False, debug=False, exit=False):
    """Process all Python files in the directory with progress tracking."""
    if recursive:
        # Walk the whole directory tree
        python_files = [
            os.path.join(root, file)
            for root, _dirs, files in os.walk(directory_path)
            for file in files
            if file.endswith(".py")
        ]
    else:
        # Only the files directly inside directory_path
        python_files = [
            os.path.join(directory_path, file)
            for file in sorted(os.listdir(directory_path))
            if file.endswith(".py")
            and os.path.isfile(os.path.join(directory_path, file))
        ]
    total_files = len(python_files)
    print(f"Processing {total_files} Python files...")
    for processed_files, file_path in enumerate(python_files, start=1):
        print(f"\nProcessing file {processed_files}/{total_files}: {file_path}")
        process_file(file_path, debug=debug, exit=exit)
def main():
    # Set up the argument parser
    parser = argparse.ArgumentParser(
        description="Generate file, class, and function docstrings for Python files."
    )
    parser.add_argument("path", help="Path to a Python file or directory.")
    parser.add_argument(
        "-r",
        "--recursive",
        action="store_true",
        help="Recursively process all Python files in the directory.",
    )
    parser.add_argument(
        "-d",
        "--debug",
        action="store_true",
        help="Copy updated files that fail to compile to /tmp/.",
    )
    parser.add_argument(
        "-e",
        "--exit",
        action="store_true",
        help="Exit as soon as an updated file fails to compile.",
    )
    args = parser.parse_args()
    if os.path.isfile(args.path):
        # A single file: process it directly
        process_file(args.path, debug=args.debug, exit=args.exit)
    elif os.path.isdir(args.path):
        # A directory: process its Python files
        process_directory(
            args.path, recursive=args.recursive, debug=args.debug, exit=args.exit
        )
    else:
        print(f"Error: {args.path} is neither a valid file nor a directory.")
    OpenAICost.print_cost_metrics()
if __name__ == "__main__":
    main()
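One detail worth highlighting in `complete()`: the rewritten source is written back only if it compiles, so a syntactically broken result never replaces the original file. The same guard can also run in-process with Python's built-in `compile()` instead of spawning `python -m py_compile`. The sketch below is a swapped-in alternative, not the script's code, and `safe_write` is a hypothetical helper name:

```python
import os
import tempfile


def safe_write(path: str, source: str) -> bool:
    """Write ``source`` to ``path`` only if it compiles as Python.

    Hypothetical helper, not part of the script: it checks syntax with the
    built-in compile() instead of running ``python -m py_compile``.
    """
    try:
        # Build a code object; invalid source raises SyntaxError
        # (null bytes raise ValueError on older Python versions)
        compile(source, path, "exec")
    except (SyntaxError, ValueError) as err:
        print(f"  Not writing {path}: {err}")
        return False
    with open(path, "w", encoding="utf-8") as handle:
        handle.write(source)
    return True


if __name__ == "__main__":
    target = os.path.join(tempfile.mkdtemp(), "example.py")
    print(safe_write(target, 'def f():\n    """Return one."""\n    return 1\n'))  # True
    print(safe_write(target, "def f(:\n"))  # False, file left untouched
```

Like `py_compile`, `compile()` only checks syntax; it does not import or run the module, so a docstring that is valid Python but inaccurate still gets through.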
Now, let's break down how this script works:
- OpenAICost Class: This class tracks the cost of using the OpenAI API based on the number of tokens used. It provides methods to send cost data and print cost metrics.
- OpenAIProvider Class: This class handles the interaction with the OpenAI API. It estimates the number of tokens in the input text, selects the appropriate model based on the token count, and sends requests to improve the text.
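- Docstring Class: This is the heart of the script. It reads a Python file, uses the OpenAI API to generate a file-level docstring plus docstrings for top-level classes and for functions, and writes the updated file back only if it still compiles. It also caches generated docstrings, keyed by a SHA1 hash of the prompt, so a prompt it has already seen is answered from the cache instead of the API; cache entries that have not been used for 14 days are pruned.
- Main Functionality: The script can process a single file or a directory of Python files. Command-line flags control its behavior: -r/--recursive descends into subdirectories, -d/--debug copies files that fail to compile to /tmp/, and -e/--exit stops at the first such failure.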
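The cache from the Docstring bullet can be sketched on its own. This is a minimal, hypothetical version: the `DocstringCache` class and its `lookup`, `store`, and `prune` methods are illustrative names, `time.time()` stands in for the script's `_get_current_timestamp()`, and loading and saving (the script's `_save_cache()`) are left out. It keeps the script's layout of one list of `sha1` / `docstring` / `last_accessed` entries per file:

```python
import hashlib
import time
from typing import Dict, List, Optional


class DocstringCache:
    """Hypothetical stand-alone version of the script's SHA1-keyed cache."""

    def __init__(self) -> None:
        # file path -> list of {"sha1", "docstring", "last_accessed"} dicts
        self.cache: Dict[str, List[dict]] = {}

    @staticmethod
    def key(prompt: str) -> str:
        # Hash the exact prompt text, so any edit to the code yields a new key
        return hashlib.sha1(prompt.encode("utf-8")).hexdigest()

    def lookup(self, file_path: str, prompt: str,
               now: Optional[float] = None) -> Optional[str]:
        """Return a cached docstring and refresh its timestamp, or None."""
        now = time.time() if now is None else now
        sha1_hash = self.key(prompt)
        for entry in self.cache.get(file_path, []):
            if entry["sha1"] == sha1_hash:
                entry["last_accessed"] = now
                return entry["docstring"]
        return None

    def store(self, file_path: str, prompt: str, docstring: str,
              now: Optional[float] = None) -> None:
        """Record a freshly generated docstring for this prompt."""
        self.cache.setdefault(file_path, []).append(
            {
                "sha1": self.key(prompt),
                "docstring": docstring,
                "last_accessed": time.time() if now is None else now,
            }
        )

    def prune(self, minutes: float, now: Optional[float] = None) -> None:
        """Drop entries not accessed within the last ``minutes`` minutes."""
        threshold = (time.time() if now is None else now) - minutes * 60
        for file_path in list(self.cache):
            kept = [
                entry
                for entry in self.cache[file_path]
                if "last_accessed" in entry and entry["last_accessed"] >= threshold
            ]
            if kept:
                self.cache[file_path] = kept
            else:
                del self.cache[file_path]  # forget files with no live entries
```

The `now` parameter exists only to make the behavior easy to test. The script prunes with `1440 * 14` minutes, i.e. 14 days; unlike the script, this sketch also deletes files whose entry lists become empty.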
This tool can significantly enhance your coding workflow by ensuring that your code is well-documented and easy to understand. I hope you find it as useful as I do! Happy coding!
📚 Further Learning
🎥 Watch this video for more: