Enhancing Code Completion for Rust in Cody
Although most LLMs are trained on corpora that include several programming languages, we often observe differential performance across languages, especially languages like Rust that are not well represented in popular training datasets. In this post, we share early results from our efforts to improve the performance of LLMs for code completion in such languages.
Developers use coding assistants like Cody to improve their workflows, relying on autocomplete to provide real-time code completions as a developer is typing in their IDE. As you start coding, or after you type a comment, Cody performs the following loop dozens of times per minute:
- Look at the context around your open files, file history, and broader repository context
- Stitch together a prompt and trigger a completion request to a large language model (LLM)
- Post-process the LLM output, truncating multi-line completions and filtering out bad suggestions with tree-sitter-based syntactic parsing
- Display the final completion suggestion to the user
Our blog post on the lifecycle of an autocomplete request describes this loop in detail.
A standard benchmark for evaluating LLMs on code completion is HumanEval, which measures the functional correctness of programs generated by LLMs from docstrings. Performance on this benchmark is typically reported as pass@k: the probability that at least one of the top k generated code samples for a problem passes its unit tests. Here's how popular LLMs perform across a variety of languages. Note the difference in performance between Python versus Rust, Ruby, and Matlab:
The figure on the left shows the pass@1 metric for proprietary models, and the figure on the right shows results for OSS models. In both cases, there is a drastic drop in pass@1 between languages like Python and JavaScript versus Ruby, PHP, Rust, and Matlab. This is likely due to how well each language is represented in the training data, as well as the complexity of each language's syntax.
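For reference, pass@k is typically computed with the unbiased estimator introduced alongside HumanEval: generate n ≥ k samples per problem, count the number c that pass the unit tests, and average over problems:

$$\text{pass@}k = \mathbb{E}_{\text{problems}}\left[ 1 - \frac{\binom{n-c}{k}}{\binom{n}{k}} \right]$$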
An important consideration as we improve code completion quality in Cody is bridging this performance gap across languages. Specifically, we ask: can we fine-tune an LLM so that it produces better code completions for a language of interest? Let's find out. But first, let's consider which metrics and trade-offs matter for an LLM to be useful for Cody's autocomplete feature.
Striking the Perfect Balance: Latency vs. Pass@1
When it comes to code completion, two metrics dominate our attention: latency and pass@1. Latency measures the response time of the coding assistant, which is crucial for surfacing results before the next keystroke. Pass@1 measures whether the generated code passes the predefined unit tests on the first attempt.
Let's look at a scatter plot comparing the latency and pass@1 metrics of a few popular LLMs on Rust. The scatter plot below shows pass@1 on the x-axis and mean latency on the y-axis (with latency decreasing as you move up), so the top-right region is best; the green region marks latency that is tolerable for autocomplete.
We observe that GPT4-turbo and Claude-3-opus offer significantly better pass@1, but their latencies, often multiple seconds, are prohibitively high; for a latency-sensitive feature like autocomplete, that is a deal breaker. Ideally, we'd like to stay under 500ms end-to-end to give developers a seamless autocomplete experience. If we focus on models within a 1000ms latency budget, mixtral-8x7b, codellama-34b-instruct, and starcoder-16b emerge as promising candidates.
Given these constraints, the question becomes: can we fine-tune one of these models to improve its code completion quality for a language of our choice?
Fine-tuning a code completion model for Rust
Rust presents unique challenges for code completion tools due to its strict guarantees around memory safety and concurrency. As shown earlier, general-purpose models underperform in Rust, struggling with the language's advanced syntax and semantics.
We use LoRA (Low-Rank Adaptation) for efficient and effective fine-tuning of the Mixtral 8x7b and Code Llama 34b LLMs. We built a fine-tuning dataset from Rust repositories with permissive licenses and simulated code completion requests on a random selection of Rust files within these repositories, applying a selection strategy that targets functions likely to yield non-trivial completions and masking a meaningful section of the code in the containing file. For LoRA, we used a lora_rank of 16 (the dimensionality of the trainable low-rank matrices, which balances model adaptability against computational cost), a batch size of 32, and a learning_rate of 0.0001.
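As a rough illustration of what one simulated completion request looks like (the file and function below are hypothetical, not drawn from our dataset), the body of a selected function is masked and becomes the completion target, while the surrounding code is kept as context:

```rust
// Hypothetical Rust file from a permissively licensed repository.
use std::collections::HashMap;

pub struct WordCounter {
    counts: HashMap<String, usize>,
}

impl WordCounter {
    pub fn new() -> Self {
        Self { counts: HashMap::new() }
    }

    pub fn add(&mut self, text: &str) {
        // ---- everything above is kept as prompt context ----
        // ---- the body below is masked and used as the completion target ----
        for word in text.split_whitespace() {
            *self.counts.entry(word.to_lowercase()).or_insert(0) += 1;
        }
    }
    // ---- the rest of the file can be supplied as additional context ----
}
```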
In one of our experiments, we fine-tuned two base models on the same dataset: Mixtral 8x7b and Code Llama 34b. The fine-tuned Mixtral showed a clear advantage in the benchmarks. The figure below compares our fine-tuned Mixtral with the base Mixtral and the StarCoder-16b model.
For example, before fine-tuning, the base models averaged a pass@1 of around 0.28 on Rust, whereas the fine-tuned Mixtral jumped to over 0.39. This leap is not just numerical: it reflects a qualitative improvement in how the model understands and generates Rust code, as illustrated below.
Showcasing improvements using concrete examples
To illustrate the effectiveness of our fine-tuned Mixtral model, let's walk through five examples where the fine-tuned model performs better.
Example 1: Given a HashMap, return true if all keys are strings in lower case or all keys are strings in upper case, else return false. The function should return false if the given HashMap is empty.
Solution from mixtral-8x7b base model:
Solution from finetuned mixtral-8x7b model:
In the solution from the base model, the function attempts to directly call the method is_ascii_lowercase() on the String type. However, no such method is available for the String type, which results in a compilation error. In the solution from the fine-tuned model, the function iterates over the keys of the HashMap using dict.iter() to access each key-value pair. It then checks each character of the key using k.chars() and applies the is_ascii_lowercase() and is_ascii_uppercase() methods to each character individually. This allows for correct checking of the case of each character in the key. Additionally, it also includes a check for an empty HashMap at the beginning of the function, which returns false if the HashMap is empty.
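The exact generated completions aren't reproduced here, but a minimal sketch of the fine-tuned model's approach looks like this (the function name, signature, and value type are illustrative assumptions):

```rust
use std::collections::HashMap;

// Sketch of the fine-tuned model's approach; signature and value type are assumed.
fn check_dict_case<V>(dict: &HashMap<String, V>) -> bool {
    // An empty map returns false, per the problem statement.
    if dict.is_empty() {
        return false;
    }
    let mut all_lower = true;
    let mut all_upper = true;
    // Check the case of every character of every key.
    for (k, _v) in dict.iter() {
        if !k.chars().all(|c| c.is_ascii_lowercase()) {
            all_lower = false;
        }
        if !k.chars().all(|c| c.is_ascii_uppercase()) {
            all_upper = false;
        }
    }
    all_lower || all_upper
}
```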
Example 2: Return the largest prime factor of n. Assume n > 1 and is not a prime.
Solution from mixtral-8x7b base model:
Solution from finetuned mixtral-8x7b model:
The base model's solution for finding the largest prime factor of a number fails due to an attempt to modify an immutable function argument, resulting in a compilation error (error[E0384]: cannot assign to immutable argument 'n'). The fine-tuned model corrects this issue by introducing a mutable binding for n within the function (let mut n = n;), allowing modifications. Additionally, the fine-tuned model enhances the algorithm by tracking the largest factor found and efficiently removing factors of 2 before checking only odd numbers. This ensures a more accurate and efficient determination of the largest prime factor, addressing both the immutability error and the logical gaps in the base model's approach.
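A minimal sketch consistent with the described fine-tuned solution (the signature is assumed):

```rust
// Sketch of the fine-tuned model's approach; signature is assumed.
fn largest_prime_factor(n: i32) -> i32 {
    // Shadow the immutable argument with a mutable binding so we can divide it down.
    let mut n = n;
    let mut largest = 1;
    // Remove all factors of 2 first.
    while n % 2 == 0 {
        largest = 2;
        n /= 2;
    }
    // Only odd candidates remain.
    let mut i = 3;
    while i * i <= n {
        while n % i == 0 {
            largest = i;
            n /= i;
        }
        i += 2;
    }
    // Whatever remains above 1 is itself a prime factor.
    if n > 1 {
        largest = n;
    }
    largest
}
```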
Example 3: You're a hungry rabbit, and you already have eaten a certain number of carrots, but now you need to eat more carrots to complete the day's meals. You should return a vector of [ total number of eaten carrots after your meals, the number of carrots left after your meals ]. If there are not enough remaining carrots, you will eat all remaining carrots, but will still be hungry.
Solution from mixtral-8x7b base model:
Solution from finetuned mixtral-8x7b model:
The base model's solution for calculating the total number of carrots eaten and the remaining carrots after meals contains a logical error in how it handles the calculation and assignment to the result vector. Specifically, it mismanages the situation where the total needed carrots exceed the remaining carrots. The base model incorrectly assigns result[0] to the remaining carrots and result[1] to the difference between total and remaining, leading to inaccurate results. This causes the assertion failures in tests where the expected total eaten carrots and remaining carrots do not match the output. In contrast, the finetuned model correctly handles this calculation by first setting result[0] to the sum of number and need, and result[1] to the difference between remaining and need. It then checks if result[1] is negative, indicating that the remaining carrots are insufficient to meet the need. If so, it adjusts result[0] by adding the negative result[1] (reducing the total eaten carrots accordingly) and sets result[1] to zero, accurately reflecting that all remaining carrots have been consumed. This approach ensures the result vector correctly represents the total eaten carrots and remaining carrots, passing all test cases without assertion errors.
In summary, the finetuned model provides a more robust solution by correctly adjusting for the cases where the needed carrots exceed the remaining stock, ensuring the output accurately reflects the total eaten and remaining carrots. This technical correction allows the finetuned model to produce correct results consistently across various test scenarios.
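A minimal sketch of the logic described above (the parameter names follow the explanation; the signature is assumed):

```rust
// Sketch of the fine-tuned model's approach; signature is assumed.
fn eat(number: i32, need: i32, remaining: i32) -> Vec<i32> {
    // Tentatively assume the full need can be eaten.
    let mut result = vec![number + need, remaining - need];
    // If the remaining carrots were insufficient, clamp to what was actually available.
    if result[1] < 0 {
        result[0] += result[1]; // reduce the total eaten by the shortfall
        result[1] = 0;          // nothing is left
    }
    result
}
```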
Example 4: This function takes a vector l and returns a vector l' such that l' is identical to l in the odd indices, while its values at the even indices are equal to the values of the even indices of l, but sorted.
Solution from mixtral-8x7b base model:
Solution from finetuned mixtral-8x7b model:
The base model's solution for sorting the values at the even indices of a vector contains a fundamental misunderstanding of how to modify and sort slices within Rust. Specifically, the use of l.splice(i..i + 2, l[i..i + 2].sort()); is incorrect because the sort() method returns (), which is not an iterator, leading to the error "() is not an iterator." This mistake indicates a misunderstanding of Rust's slicing and sorting mechanisms. It is almost as if the model is pattern-matching against patterns in another language and naively interpolating these into Rust-like syntax.
In contrast, the finetuned model provides a correct implementation by explicitly sorting the elements at even indices using nested loops. The outer loop iterates through the indices of the vector, and the inner loop compares and swaps elements at even indices if they are out of order. This ensures that the even-indexed elements are sorted without altering the odd-indexed elements. The finetuned model's method is more verbose but correctly manipulates the vector by directly comparing and swapping elements, thus producing the correct sorted vector for even indices while leaving the odd indices unchanged.
In summary, the base model fails due to an incorrect use of slicing and sorting functions, resulting in a compilation error. The finetuned model correctly sorts the even-indexed elements by implementing a nested loop approach that directly compares and swaps elements, ensuring the even indices are sorted without affecting the odd indices. This solution passes all provided test cases, demonstrating its correctness and robustness.
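A minimal sketch of the nested-loop approach described above (the signature is assumed):

```rust
// Sketch of the fine-tuned model's approach; signature is assumed.
fn sort_even(l: Vec<i32>) -> Vec<i32> {
    let mut out = l;
    // Compare and swap only even-indexed elements; odd indices are never touched.
    for i in (0..out.len()).step_by(2) {
        for j in ((i + 2)..out.len()).step_by(2) {
            if out[j] < out[i] {
                out.swap(i, j);
            }
        }
    }
    out
}
```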
Example 5: triples_sum_to_zero takes a vector of integers as an input. It returns true if there are three distinct elements in the vector that sum to zero, and false otherwise.
Solution from mixtral-8x7b base model:
Solution from finetuned mixtral-8x7b model:
The base model's solution for determining if a vector contains three distinct elements that sum to zero is incorrect due to a flawed approach that uses a hash set to check for pairs that sum to a given element's negative counterpart. This method only checks for pairs summing to zero rather than triples and fails to correctly identify all valid combinations, leading to incorrect results and assertion failures. The finetuned model, however, implements a correct approach by first sorting the vector and then using a three-pointer technique to find the triplets. This involves iterating through the vector with an index i and using two additional pointers, j and k, to check the sum of elements at these positions. If the sum of the elements at i, j, and k equals zero, the function returns true. Otherwise, it adjusts the pointers based on whether the sum is less than or greater than zero. This method ensures that all possible triplets are checked efficiently.
In summary, the base model's error lies in its incorrect use of a hash set to find pairs rather than triplets, resulting in logical errors. The finetuned model correctly uses a sorted array and a three-pointer technique to accurately identify triplets summing to zero, demonstrating a more robust and effective solution. This correction ensures accurate results across various test cases.
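A minimal sketch of the sorted three-pointer approach described above (the signature is assumed):

```rust
// Sketch of the fine-tuned model's approach; signature is assumed.
fn triples_sum_to_zero(l: Vec<i32>) -> bool {
    if l.len() < 3 {
        return false;
    }
    // Sort first so the two inner pointers can walk toward each other.
    let mut nums = l;
    nums.sort();
    for i in 0..nums.len() - 2 {
        // i is the fixed index; j and k move toward each other.
        let (mut j, mut k) = (i + 1, nums.len() - 1);
        while j < k {
            let sum = nums[i] + nums[j] + nums[k];
            if sum == 0 {
                return true;
            } else if sum < 0 {
                j += 1; // need a larger value
            } else {
                k -= 1; // need a smaller value
            }
        }
    }
    false
}
```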
Looking at the above examples, the fine-tuned model offers clear advantages over the base model in solving coding problems accurately and efficiently. It handles errors and corrections effectively by using proper language constructs and correct logic, and it covers more edge cases, better satisfying each problem's requirements. It also employs more efficient algorithms, as seen in sorting the even-indexed elements and using the three-pointer technique for triplet-sum detection. Its solutions are clearer and more robust, leading to successful test outcomes and fewer assertion errors. Overall, the fine-tuned model demonstrates an improved grasp of problem-solving in Rust, providing more reliable and efficient solutions than the base model.
Looking ahead: enriching evals & improving other languages
The success we've seen with Rust is just the beginning, and we are excited to share our progress on other languages in the coming weeks. An important part of improving code completions is developing evaluation suites that extend well beyond current HumanEval-style benchmarks and better represent real-world coding within large, complex repositories. We have developed context-aware datasets and better metrics in this direction and are excited to share more soon. We have also rolled out our fine-tuned completion models to a subset of completion requests targeting the languages and request types that appear to benefit, and we are observing increases in online metrics, which we intend to share in a follow-up post in the coming weeks.
We are thrilled about the improvements made and the potential for even greater enhancements in the future. If you're as excited as we are, we invite you to try Cody in your projects by installing it for VS Code or IntelliJ at cody.dev. Your feedback is invaluable as we continue to refine and expand Cody's capabilities. Together, let's push the boundaries of what coding assistants can do and transform the software development landscape.