Skip to content

Commit

Permalink
[External] [stdlib] add b64decode (#38800)
Browse files Browse the repository at this point in the history
[External] [stdlib] add b64decode

Follows the decode algorithm from the same paper used for `b64encode` (https://arxiv.org/abs/1704.00605).

The Llama3 `tokenizer.model` stores the tokens with base64 encoding so
demand for this may increase. 😃

ORIGINAL_AUTHOR=Michael Kowalski
<1331470+mikowals@users.noreply.github.com>
PUBLIC_PR_LINK=modularml#2364

---------

Co-authored-by: Michael Kowalski <1331470+mikowals@users.noreply.github.com>
Closes modularml#2364
MODULAR_ORIG_COMMIT_REV_ID: de91cca69272570a52fcbf28a5c51c8d7fe75364
  • Loading branch information
2 people authored and JoeLoser committed Apr 30, 2024
1 parent 06d3e48 commit 2af3b65
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 2 deletions.
2 changes: 1 addition & 1 deletion stdlib/src/base64/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# ===----------------------------------------------------------------------=== #
"""Implements the base64 package."""

from .base64 import b64encode
from .base64 import b64encode, b64decode
80 changes: 80 additions & 0 deletions stdlib/src/base64/base64.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,39 @@ from base64 import b64encode
from collections import List
from sys import simdwidthof

# ===----------------------------------------------------------------------===#
# Utilities
# ===----------------------------------------------------------------------===#


@always_inline
fn _ascii_to_value(char: String) -> Int:
    """Maps one base64 alphabet character to the 6-bit value it encodes.

    Args:
        char: A single character string.

    Returns:
        The value 0-63 for a character of the base64 alphabet, 0 for the
        padding character "=", or -1 for any other character.
    """
    var code = ord(char)

    # Padding carries no data bits, so it decodes as zero.
    if char == "=":
        return 0
    # "A"-"Z" map to 0-25.
    if ord("A") <= code <= ord("Z"):
        return code - ord("A")
    # "a"-"z" map to 26-51.
    if ord("a") <= code <= ord("z"):
        return code - ord("a") + 26
    # "0"-"9" map to 52-61.
    if ord("0") <= code <= ord("9"):
        return code - ord("0") + 52
    if char == "+":
        return 62
    if char == "/":
        return 63
    # Anything else is outside the base64 alphabet.
    return -1


# ===----------------------------------------------------------------------===#
# b64encode
# ===----------------------------------------------------------------------===#
Expand Down Expand Up @@ -71,3 +104,50 @@ fn b64encode(str: String) -> String:
out.append(ord("="))
out.append(0)
return String(out^)


# ===----------------------------------------------------------------------===#
# b64decode
# ===----------------------------------------------------------------------===#


@always_inline
fn b64decode(str: String) -> String:
    """Performs base64 decoding on the input string.

    The input is consumed in groups of four characters, each group yielding
    up to three bytes. Decoding stops at the first group whose third or
    fourth character is the "=" padding character.

    The length and alphabet checks are `debug_assert`s rather than raised
    errors, so malformed input is not reported to the caller; pass only
    well-formed, padded base64.

    Args:
        str: A base64 encoded string whose length is a multiple of 4.

    Returns:
        The decoded string.
    """
    var n = len(str)
    debug_assert(n % 4 == 0, "Input length must be divisible by 4")

    # Every 4 input characters decode to at most 3 bytes; the extra slot is
    # for the null terminator appended after the loop.
    var p = List[Int8](capacity=(n // 4) * 3 + 1)

    # This algorithm is based on https://arxiv.org/abs/1704.00605
    for i in range(0, n, 4):
        var a = _ascii_to_value(str[i])
        var b = _ascii_to_value(str[i + 1])
        var c = _ascii_to_value(str[i + 2])
        var d = _ascii_to_value(str[i + 3])

        debug_assert(
            a >= 0 and b >= 0 and c >= 0 and d >= 0,
            "Unexpected character encountered",
        )

        # Byte 1: all 6 bits of `a` followed by the top 2 bits of `b`.
        p.append((a << 2) | (b >> 4))

        # "xx==": the group encodes a single byte.
        if str[i + 2] == "=":
            break

        # Byte 2: the low 4 bits of `b` followed by the top 4 bits of `c`.
        p.append(((b & 0x0F) << 4) | (c >> 2))

        # "xxx=": the group encodes two bytes.
        if str[i + 3] == "=":
            break

        # Byte 3: the low 2 bits of `c` followed by all 6 bits of `d`.
        p.append(((c & 0x03) << 6) | d)

    # Null-terminate and hand the buffer to the String, as `b64encode` does.
    p.append(0)
    return String(p^)
22 changes: 21 additions & 1 deletion stdlib/test/base64/test_base64.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# ===----------------------------------------------------------------------=== #
# RUN: %mojo %s

from base64 import b64encode
from base64 import b64encode, b64decode
from testing import assert_equal


Expand All @@ -33,5 +33,25 @@ def test_b64encode():
assert_equal(b64encode("ABCDEFabcdef"), "QUJDREVGYWJjZGVm")


def test_b64decode():
    """Checks `b64decode` against padded, unpadded and non-ASCII inputs."""
    # Two padding characters.
    assert_equal(b64decode("YQ=="), "a")

    # One padding character.
    assert_equal(b64decode("Zm8="), "fo")

    assert_equal(b64decode("SGVsbG8gTW9qbyEhIQ=="), "Hello Mojo!!!")

    # Payload containing a multi-byte character.
    assert_equal(b64decode("SGVsbG8g8J+UpSEhIQ=="), "Hello 🔥!!!")

    var pangram_encoded = String(
        "dGhlIHF1aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIHRoZSBsYXp5IGRvZw=="
    )
    assert_equal(
        b64decode(pangram_encoded),
        "the quick brown fox jumps over the lazy dog",
    )

    # No padding needed.
    assert_equal(b64decode("QUJDREVGYWJjZGVm"), "ABCDEFabcdef")


def main():
    """Runs the base64 encode and decode tests."""
    test_b64encode()
    test_b64decode()

0 comments on commit 2af3b65

Please sign in to comment.