init CTFd source
Some checks are pending
Linting / Linting (3.11) (push) Waiting to run
Mirror core-theme / mirror (push) Waiting to run

This commit is contained in:
gkr
2025-12-25 09:39:21 +08:00
commit 2e06f92c64
1047 changed files with 150349 additions and 0 deletions

24
tests/utils/__init__.py Normal file
View File

@@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from CTFd.utils import get_config, set_config
from tests.helpers import create_ctfd, destroy_ctfd
def test_ctf_version_is_set():
"""Does ctf_version get set correctly"""
app = create_ctfd()
with app.app_context():
assert get_config("ctf_version") == app.VERSION
destroy_ctfd(app)
def test_get_config_and_set_config():
"""Does get_config and set_config work properly"""
app = create_ctfd()
with app.app_context():
assert get_config("setup") == True
config = set_config("TEST_CONFIG_ENTRY", "test_config_entry")
assert config.value == "test_config_entry"
assert get_config("TEST_CONFIG_ENTRY") == "test_config_entry"
destroy_ctfd(app)

239
tests/utils/test_ctftime.py Normal file
View File

@@ -0,0 +1,239 @@
from datetime import datetime as DateTime
from datetime import timezone as TimeZone
import pytest
from CTFd.models import Solves
from CTFd.utils.dates import (
ctf_ended,
ctf_started,
isoformat,
unix_time,
unix_time_millis,
unix_time_to_utc,
)
from CTFd.utils.modes import TEAMS_MODE
from tests.helpers import (
create_ctfd,
ctftime,
destroy_ctfd,
gen_challenge,
gen_flag,
gen_team,
login_as_user,
register_user,
)
def test_ctftime_prevents_accessing_challenges_before_ctf():
"""Test that the ctftime function prevents users from accessing challenges before the ctf"""
app = create_ctfd()
with app.app_context():
with ctftime.init():
register_user(app)
chal = gen_challenge(app.db)
chal_id = chal.id
gen_flag(app.db, challenge_id=chal.id, content="flag")
with ctftime.not_started():
client = login_as_user(app)
r = client.get("/challenges")
assert r.status_code == 403
with client.session_transaction() as sess:
data = {"key": "flag", "nonce": sess.get("nonce")}
r = client.get("/api/v1/challenges/{}".format(chal_id), data=data)
data = r.get_data(as_text=True)
assert r.status_code == 403
solve_count = app.db.session.query(app.db.func.count(Solves.id)).first()[0]
assert solve_count == 0
destroy_ctfd(app)
def test_ctftime_redirects_to_teams_page_in_teams_mode_before_ctf():
"""
Test that the ctftime function redirects users to the team creation page in teams mode before the ctf if the user
has no team yet.
"""
app = create_ctfd(user_mode=TEAMS_MODE)
with app.app_context():
with ctftime.init():
register_user(app)
chal = gen_challenge(app.db)
gen_flag(app.db, challenge_id=chal.id, content="flag")
with ctftime.not_started():
client = login_as_user(app)
r = client.get("/challenges")
assert r.status_code == 302
gen_team(app.db, name="test", password="password")
with login_as_user(app) as client:
r = client.get("/teams/join")
assert r.status_code == 200
with client.session_transaction() as sess:
data = {
"name": "test",
"password": "password",
"nonce": sess.get("nonce"),
}
r = client.post("/teams/join", data=data)
assert r.status_code == 302
with ctftime.not_started():
client = login_as_user(app)
r = client.get("/challenges")
assert r.status_code == 403
destroy_ctfd(app)
def test_ctftime_allows_accessing_challenges_during_ctf():
"""Test that the ctftime function allows accessing challenges during the ctf"""
app = create_ctfd()
with app.app_context():
with ctftime.init():
register_user(app)
chal = gen_challenge(app.db)
chal_id = chal.id
gen_flag(app.db, challenge_id=chal.id, content="flag")
with ctftime.started():
client = login_as_user(app)
r = client.get("/challenges")
assert r.status_code == 200
with client.session_transaction() as sess:
data = {
"submission": "flag",
"challenge_id": chal_id,
"nonce": sess.get("nonce"),
}
r = client.post("/api/v1/challenges/attempt", data=data)
assert r.status_code == 200
solve_count = app.db.session.query(app.db.func.count(Solves.id)).first()[0]
assert solve_count == 1
destroy_ctfd(app)
def test_ctftime_prevents_accessing_challenges_after_ctf():
"""Test that the ctftime function prevents accessing challenges after the ctf"""
app = create_ctfd()
with app.app_context():
with ctftime.init():
register_user(app)
chal = gen_challenge(app.db)
chal_id = chal.id
gen_flag(app.db, challenge_id=chal.id, content="flag")
with ctftime.ended():
client = login_as_user(app)
r = client.get("/challenges")
assert r.status_code == 403
with client.session_transaction() as sess:
data = {
"submission": "flag",
"challenge_id": chal_id,
"nonce": sess.get("nonce"),
}
r = client.post("/api/v1/challenges/attempt", data=data)
assert r.status_code == 403
solve_count = app.db.session.query(app.db.func.count(Solves.id)).first()[0]
assert solve_count == 0
destroy_ctfd(app)
def test_ctf_started():
"""
Tests that the ctf_started function returns the correct value
:return:
"""
app = create_ctfd()
with app.app_context():
assert ctf_started() is True
with ctftime.init():
with ctftime.not_started():
ctf_started()
assert ctf_started() is False
with ctftime.started():
assert ctf_started() is True
with ctftime.ended():
assert ctf_started() is True
destroy_ctfd(app)
def test_ctf_ended():
"""
Tests that the ctf_ended function returns the correct value
"""
app = create_ctfd()
with app.app_context():
assert ctf_ended() is False
with ctftime.init():
with ctftime.not_started():
assert ctf_ended() is False
with ctftime.started():
assert ctf_ended() is False
with ctftime.ended():
assert ctf_ended() is True
destroy_ctfd(app)
def test_unix_time():
"""
Tests that the unix_time function returns the correct value and fails gracefully for strange inputs
"""
assert unix_time(DateTime(2017, 1, 1)) == 1483228800
assert type(unix_time(DateTime(2017, 1, 1))) == int
assert unix_time(None) is None
assert unix_time("test") is None
assert unix_time(1) is None
def test_unix_time_millis():
"""
Tests that the unix_time function returns the correct value and fails gracefully for strange inputs
"""
# Aware datetime object
assert unix_time_millis(DateTime(2017, 1, 1)) == 1483228800000
assert type(unix_time_millis(DateTime(2017, 1, 1))) == int
assert unix_time_millis(None) is None
assert unix_time_millis("test") is None
assert unix_time_millis(1) is None
def test_unix_time_to_utc():
"""
Tests that the unix_time function returns the correct value and fails gracefully for strange inputs
"""
assert unix_time_to_utc(0) == DateTime(1970, 1, 1)
assert unix_time_to_utc(1483228800) == DateTime(2017, 1, 1)
assert type(unix_time_to_utc(1483228800)) == DateTime
assert unix_time_to_utc(None) is None
with pytest.raises(TypeError):
unix_time_to_utc("test")
with pytest.raises(TypeError):
unix_time_to_utc(DateTime(2017, 1, 1))
def test_isoformat():
"""
Tests that the unix_time function returns the correct value and fails gracefully for strange inputs
"""
assert (
isoformat(DateTime(2017, 1, 1, tzinfo=TimeZone.utc))
== "2017-01-01T00:00:00+00:00Z"
)
assert isoformat(DateTime(2017, 1, 1)) == "2017-01-01T00:00:00Z"
assert isoformat(DateTime(2017, 1, 1, tzinfo=None)) == "2017-01-01T00:00:00Z"
assert isoformat(None) is None
assert isoformat("test") is None
assert isoformat(1) is None

561
tests/utils/test_email.py Normal file
View File

@@ -0,0 +1,561 @@
from email.message import EmailMessage
from unittest.mock import Mock, patch
import requests
from CTFd.models import Users
from CTFd.utils import get_config, set_config
from CTFd.utils.crypto import verify_password
from CTFd.utils.email import (
check_email_is_blacklisted,
check_email_is_whitelisted,
forgot_password,
sendmail,
successful_registration_notification,
verify_email_address,
)
from tests.helpers import create_ctfd, destroy_ctfd, login_as_user, register_user
@patch("smtplib.SMTP")
def test_sendmail_with_smtp_from_config_file(mock_smtp):
"""Does sendmail work properly with simple SMTP mail servers using file configuration"""
app = create_ctfd()
with app.app_context():
app.config["MAIL_SERVER"] = "localhost"
app.config["MAIL_PORT"] = "25"
app.config["MAIL_USEAUTH"] = "True"
app.config["MAIL_USERNAME"] = "username"
app.config["MAIL_PASSWORD"] = "password"
ctf_name = get_config("ctf_name")
from_addr = get_config("mailfrom_addr") or app.config.get("MAILFROM_ADDR")
from_addr = "{} <{}>".format(ctf_name, from_addr)
to_addr = "user@user.com"
msg = "this is a test"
sendmail(to_addr, msg)
ctf_name = get_config("ctf_name")
email_msg = EmailMessage()
email_msg.set_content(msg)
email_msg["Subject"] = "Message from {0}".format(ctf_name)
email_msg["From"] = from_addr
email_msg["To"] = to_addr
mock_smtp.return_value.send_message.assert_called()
assert str(mock_smtp.return_value.send_message.call_args[0][0]) == str(
email_msg
)
destroy_ctfd(app)
@patch("smtplib.SMTP")
def test_sendmail_with_smtp_from_db_config(mock_smtp):
"""Does sendmail work properly with simple SMTP mail servers using database configuration"""
app = create_ctfd()
with app.app_context():
set_config("mail_server", "localhost")
set_config("mail_port", 25)
set_config("mail_useauth", True)
set_config("mail_username", "username")
set_config("mail_password", "password")
ctf_name = get_config("ctf_name")
from_addr = get_config("mailfrom_addr") or app.config.get("MAILFROM_ADDR")
from_addr = "{} <{}>".format(ctf_name, from_addr)
to_addr = "user@user.com"
msg = "this is a test"
sendmail(to_addr, msg)
ctf_name = get_config("ctf_name")
email_msg = EmailMessage()
email_msg.set_content(msg)
email_msg["Subject"] = "Message from {0}".format(ctf_name)
email_msg["From"] = from_addr
email_msg["To"] = to_addr
mock_smtp.return_value.send_message.assert_called()
assert str(mock_smtp.return_value.send_message.call_args[0][0]) == str(
email_msg
)
destroy_ctfd(app)
@patch.object(requests, "post")
def test_sendmail_with_mailgun_from_config_file(fake_post_request):
"""Does sendmail work properly with Mailgun using file configuration"""
app = create_ctfd()
with app.app_context():
app.config["MAILGUN_API_KEY"] = "key-1234567890-file-config"
app.config["MAILGUN_BASE_URL"] = "https://api.mailgun.net/v3/file.faked.com"
to_addr = "user@user.com"
msg = "this is a test"
sendmail(to_addr, msg)
fake_response = Mock()
fake_post_request.return_value = fake_response
fake_response.status_code = 200
status, message = sendmail(to_addr, msg)
args, kwargs = fake_post_request.call_args
assert args[0] == "https://api.mailgun.net/v3/file.faked.com/messages"
assert kwargs["auth"] == ("api", "key-1234567890-file-config")
assert kwargs["timeout"] == 1.0
assert kwargs["data"] == {
"to": ["user@user.com"],
"text": "this is a test",
"from": "CTFd <noreply@examplectf.com>",
"subject": "Message from CTFd",
}
assert fake_response.status_code == 200
assert status is True
assert message == "Email sent"
destroy_ctfd(app)
@patch.object(requests, "post")
def test_sendmail_with_mailgun_from_db_config(fake_post_request):
"""Does sendmail work properly with Mailgun using database configuration"""
app = create_ctfd()
with app.app_context():
app.config["MAILGUN_API_KEY"] = "key-1234567890-file-config"
app.config["MAILGUN_BASE_URL"] = "https://api.mailgun.net/v3/file.faked.com"
# db values should take precedence over file values
set_config("mailgun_api_key", "key-1234567890-db-config")
set_config("mailgun_base_url", "https://api.mailgun.net/v3/db.faked.com")
to_addr = "user@user.com"
msg = "this is a test"
sendmail(to_addr, msg)
fake_response = Mock()
fake_post_request.return_value = fake_response
fake_response.status_code = 200
status, message = sendmail(to_addr, msg)
args, kwargs = fake_post_request.call_args
assert args[0] == "https://api.mailgun.net/v3/db.faked.com/messages"
assert kwargs["auth"] == ("api", "key-1234567890-db-config")
assert kwargs["timeout"] == 1.0
assert kwargs["data"] == {
"to": ["user@user.com"],
"text": "this is a test",
"from": "CTFd <noreply@examplectf.com>",
"subject": "Message from CTFd",
}
assert fake_response.status_code == 200
assert status is True
assert message == "Email sent"
destroy_ctfd(app)
@patch("smtplib.SMTP")
def test_verify_email(mock_smtp):
"""Does verify_email send emails"""
app = create_ctfd()
with app.app_context():
set_config("mail_server", "localhost")
set_config("mail_port", 25)
set_config("mail_useauth", True)
set_config("mail_username", "username")
set_config("mail_password", "password")
set_config("verify_emails", True)
ctf_name = get_config("ctf_name")
from_addr = get_config("mailfrom_addr") or app.config.get("MAILFROM_ADDR")
from_addr = "{} <{}>".format(ctf_name, from_addr)
to_addr = "user@user.com"
urandom_value = b"\xff" * 32
with patch("os.urandom", return_value=urandom_value):
verify_email_address(to_addr)
msg = (
"Welcome to CTFd!\n\n"
"Click the following link to confirm and activate your account:\n"
"http://localhost/confirm/ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff\n\n"
"If the link is not clickable, try copying and pasting it into your browser."
)
ctf_name = get_config("ctf_name")
email_msg = EmailMessage()
email_msg.set_content(msg)
email_msg["Subject"] = "Confirm your account for {ctf_name}".format(
ctf_name=ctf_name
)
email_msg["From"] = from_addr
email_msg["To"] = to_addr
mock_smtp.return_value.send_message.assert_called()
assert str(mock_smtp.return_value.send_message.call_args[0][0]) == str(
email_msg
)
destroy_ctfd(app)
@patch("smtplib.SMTP")
def test_successful_registration_email(mock_smtp):
"""Does successful_registration_notification send emails"""
app = create_ctfd()
with app.app_context():
set_config("mail_server", "localhost")
set_config("mail_port", 25)
set_config("mail_useauth", True)
set_config("mail_username", "username")
set_config("mail_password", "password")
set_config("verify_emails", True)
ctf_name = get_config("ctf_name")
from_addr = get_config("mailfrom_addr") or app.config.get("MAILFROM_ADDR")
from_addr = "{} <{}>".format(ctf_name, from_addr)
to_addr = "user@user.com"
successful_registration_notification(to_addr)
msg = "You've successfully registered for CTFd!"
email_msg = EmailMessage()
email_msg.set_content(msg)
email_msg["Subject"] = "Successfully registered for {ctf_name}".format(
ctf_name=ctf_name
)
email_msg["From"] = from_addr
email_msg["To"] = to_addr
mock_smtp.return_value.send_message.assert_called()
assert str(mock_smtp.return_value.send_message.call_args[0][0]) == str(
email_msg
)
destroy_ctfd(app)
def test_email_whitelist():
app = create_ctfd()
with app.app_context():
set_config("domain_whitelist", "example.com")
test_cases_specific_domain = [
("john.doe@example.com", True),
("john.doe@ext.example.com", False),
("john.doe@example.io", False),
("john.doe@example.co", False),
("john.doe@ample.com", False),
("john.doe@exexample.com", False),
]
for case in test_cases_specific_domain:
email, expected = case
assert check_email_is_whitelisted(email) is expected
set_config("domain_whitelist", "*.example.com")
test_cases_wildcard_domain = [
("john.doe@ext.example.com", True),
("john.doe@.example.com", True), # this is expected behaviour
("john.doe@example.com", False),
("john.doe@example.io", False),
("john.doe@example.co", False),
("john.doe@ample.com", False),
("john.doe@exexample.com", False),
("john.doe@*example.com", False),
("john.doe@*.example.com", False),
]
for case in test_cases_wildcard_domain:
email, expected = case
assert check_email_is_whitelisted(email) is expected
set_config("domain_whitelist", "example.com, *.example.com")
test_cases_combined_domain = [
("john.doe@example.com", True),
("john.doe@ext.example.com", True),
("john.doe@.example.com", True), # this is expected behaviour
("john.doe@example.io", False),
("john.doe@example.co", False),
("john.doe@gmail.com", False),
("john.doe@ample.com", False),
("john.doe@exexample.com", False),
("john.doe@*example.com", False),
("john.doe@*.example.com", False),
]
for case in test_cases_combined_domain:
email, expected = case
assert check_email_is_whitelisted(email) is expected
set_config("domain_whitelist", "example.com, uni.acme.com, *.edu, *.edu.de")
test_cases_multiple_combined_domains = [
("john.doe@example.com", True),
("john.doe@uni.acme.com", True),
("john.doe@uni.edu", True),
("john.doe@cs.uni.edu", True),
("john.doe@mail.cs.uni.edu", True),
("john.doe@uni.edu.de", True),
("john.doe@cs.uni.edu.de", True),
("john.doe@mail.cs.uni.edu.de", True),
("john.doe@gmail.com", False),
("john.doe@ample.com", False),
("john.doe@example1.com", False),
("john.doe@1example.com", False),
("john.doe@ext.example.com", False),
("john.doe@cs.acme.com", False),
("john.doe@edu.com", False),
("john.doe@mail.uni.acme.com", False),
("john.doe@edu", False),
]
for case in test_cases_multiple_combined_domains:
email, expected = case
assert check_email_is_whitelisted(email) is expected
destroy_ctfd(app)
def test_email_blacklist():
app = create_ctfd()
with app.app_context():
set_config("domain_blacklist", "example.com")
test_cases_specific_domain = [
("john.doe@example.com", True),
("john.doe@ext.example.com", False),
("john.doe@example.io", False),
("john.doe@example.co", False),
("john.doe@ample.com", False),
("john.doe@exexample.com", False),
]
for case in test_cases_specific_domain:
email, expected = case
assert check_email_is_blacklisted(email) is expected
set_config("domain_blacklist", "*.example.com")
test_cases_wildcard_domain = [
("john.doe@ext.example.com", True),
("john.doe@.example.com", True),
("john.doe@example.com", False),
("john.doe@example.io", False),
("john.doe@example.co", False),
("john.doe@ample.com", False),
("john.doe@exexample.com", False),
("john.doe@*example.com", True),
("john.doe@*.example.com", True),
]
for case in test_cases_wildcard_domain:
email, expected = case
assert check_email_is_blacklisted(email) is expected
set_config("domain_blacklist", "example.com, *.example.com")
test_cases_combined_domain = [
("john.doe@example.com", True),
("john.doe@ext.example.com", True),
("john.doe@.example.com", True),
("john.doe@example.io", False),
("john.doe@example.co", False),
("john.doe@gmail.com", False),
("john.doe@ample.com", False),
("john.doe@exexample.com", False),
("john.doe@*example.com", True),
("john.doe@*.example.com", True),
]
for case in test_cases_combined_domain:
email, expected = case
assert check_email_is_blacklisted(email) is expected
set_config("domain_blacklist", "example.com, uni.acme.com, *.edu, *.edu.de")
test_cases_multiple_combined_domains = [
("john.doe@example.com", True),
("john.doe@uni.acme.com", True),
("john.doe@uni.edu", True),
("john.doe@cs.uni.edu", True),
("john.doe@mail.cs.uni.edu", True),
("john.doe@uni.edu.de", True),
("john.doe@cs.uni.edu.de", True),
("john.doe@mail.cs.uni.edu.de", True),
("john.doe@gmail.com", False),
("john.doe@ample.com", False),
("john.doe@example1.com", False),
("john.doe@1example.com", False),
("john.doe@ext.example.com", False),
("john.doe@cs.acme.com", False),
("john.doe@edu.com", False),
("john.doe@mail.uni.acme.com", False),
("john.doe@edu", False),
]
for case in test_cases_multiple_combined_domains:
email, expected = case
assert check_email_is_blacklisted(email) is expected
destroy_ctfd(app)
def test_confirm_links_single_use():
"""Test that confirm links are single use"""
app = create_ctfd()
with app.app_context():
set_config("mail_server", "localhost")
set_config("mail_port", 25)
set_config("mail_useauth", True)
set_config("mail_username", "username")
set_config("mail_password", "password")
set_config("verify_emails", True)
register_user(app)
client = login_as_user(app)
to_addr = "user@examplectf.com"
urandom_value = b"\xff" * 32
with patch("os.urandom", return_value=urandom_value):
verify_email_address(to_addr)
r = client.get(
"http://localhost/confirm/ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
)
user = Users.query.filter_by(email=to_addr).first()
assert user.verified is True
r = client.get(
"http://localhost/confirm/ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
)
assert (
"Your confirmation link is invalid, please generate a new one"
in r.get_data(as_text=True)
)
destroy_ctfd(app)
@patch("smtplib.SMTP")
def test_confirm_link_tokens_unique(mock_smtp):
app = create_ctfd()
with app.app_context():
set_config("mail_server", "localhost")
set_config("mail_port", 25)
set_config("mail_useauth", True)
set_config("mail_username", "username")
set_config("mail_password", "password")
set_config("verify_emails", True)
register_user(app, name="user1", email="user1@examplectf.com")
register_user(app, name="user2", email="user2@examplectf.com")
verify_email_address("user1@examplectf.com")
call1 = str(mock_smtp.return_value.send_message.call_args[0][0])
verify_email_address("user1@examplectf.com")
call2 = str(mock_smtp.return_value.send_message.call_args[0][0])
verify_email_address("user2@examplectf.com")
call3 = str(mock_smtp.return_value.send_message.call_args[0][0])
assert call1 != call2
assert call2 != call3
assert call1 != call3
destroy_ctfd(app)
def test_reset_password_links_single_use():
"""Test that reset password links are single use"""
app = create_ctfd()
with app.app_context():
register_user(app)
set_config("mail_server", "localhost")
set_config("mail_port", 25)
set_config("mail_useauth", True)
set_config("mail_username", "username")
set_config("mail_password", "password")
with app.test_client() as client:
client.get("/reset_password")
# Build reset password data
with client.session_transaction() as sess:
data = {"nonce": sess.get("nonce"), "email": "user@examplectf.com"}
# Issue the password reset request
urandom_value = b"\xff" * 32
with patch("os.urandom", return_value=urandom_value):
client.post("/reset_password", data=data)
# Get user's original password
user = Users.query.filter_by(email="user@examplectf.com").first()
# Build the POST data
with client.session_transaction() as sess:
data = {"nonce": sess.get("nonce"), "password": "passwordtwo"}
# Do the password reset
r = client.get(
"/reset_password/ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
)
assert r.status_code == 200
r = client.post(
"/reset_password/ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
data=data,
)
assert r.status_code == 302
# Make sure that the user's password changed
user = Users.query.filter_by(email="user@examplectf.com").first()
assert verify_password("passwordtwo", user.password)
# Attempt to re-use the password reset link
r = client.get(
"/reset_password/ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
)
assert (
"Your reset link is invalid, please generate a new one"
in r.get_data(as_text=True)
)
assert r.status_code == 200
with client.session_transaction() as sess:
data = {"nonce": sess.get("nonce"), "password": "passwordthree"}
r = client.post(
"/reset_password/ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
data=data,
)
assert r.status_code == 200
# Password should not have been changed
user = Users.query.filter_by(email="user@examplectf.com").first()
assert verify_password("passwordthree", user.password) is False
destroy_ctfd(app)
@patch("smtplib.SMTP")
def test_reset_password_link_tokens_unique(mock_smtp):
app = create_ctfd()
with app.app_context():
set_config("mail_server", "localhost")
set_config("mail_port", 25)
set_config("mail_useauth", True)
set_config("mail_username", "username")
set_config("mail_password", "password")
set_config("verify_emails", True)
register_user(app, name="user1", email="user1@examplectf.com")
register_user(app, name="user2", email="user2@examplectf.com")
forgot_password("user1@examplectf.com")
call1 = str(mock_smtp.return_value.send_message.call_args[0][0])
forgot_password("user1@examplectf.com")
call2 = str(mock_smtp.return_value.send_message.call_args[0][0])
forgot_password("user2@examplectf.com")
call3 = str(mock_smtp.return_value.send_message.call_args[0][0])
assert call1 != call2
assert call2 != call3
assert call1 != call3
destroy_ctfd(app)

View File

@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
import string
from CTFd.utils.encoding import base64decode, base64encode, hexdecode, hexencode
def test_hexencode():
value = (
"303132333435363738396162636465666768696a6b6c6d6e6f7071727374757677"
"78797a4142434445464748494a4b4c4d4e4f505152535455565758595a21222324"
"25262728292a2b2c2d2e2f3a3b3c3d3e3f405b5c5d5e5f607b7c7d7e20090a0d0b0c"
)
assert hexencode(string.printable) == value
def test_hexdecode():
saved = (
"303132333435363738396162636465666768696a6b6c6d6e6f7071727374757677"
"78797a4142434445464748494a4b4c4d4e4f505152535455565758595a21222324"
"25262728292a2b2c2d2e2f3a3b3c3d3e3f405b5c5d5e5f607b7c7d7e20090a0d0b0c"
)
assert hexdecode(saved) == string.printable
def test_base64encode():
"""The base64encode wrapper works properly"""
assert base64encode("abc123") == "YWJjMTIz"
assert (
base64encode('"test@mailinator.com".DGxeoA.lCssU3M2QuBfohO-FtdgDQLKbU4')
== "InRlc3RAbWFpbGluYXRvci5jb20iLkRHeGVvQS5sQ3NzVTNNMlF1QmZvaE8tRnRkZ0RRTEtiVTQ"
)
assert (
base64encode("user+user@examplectf.com") == "dXNlcit1c2VyQGV4YW1wbGVjdGYuY29t"
)
assert base64encode("😆") == "8J-Yhg"
def test_base64decode():
"""The base64decode wrapper works properly"""
assert base64decode("YWJjMTIz") == "abc123"
assert (
base64decode(
"InRlc3RAbWFpbGluYXRvci5jb20iLkRHeGVvQS5sQ3NzVTNNMlF1QmZvaE8tRnRkZ0RRTEtiVTQ"
)
== '"test@mailinator.com".DGxeoA.lCssU3M2QuBfohO-FtdgDQLKbU4'
)
assert (
base64decode("dXNlcit1c2VyQGV4YW1wbGVjdGYuY29t") == "user+user@examplectf.com"
)
assert base64decode("8J-Yhg") == "😆"

250
tests/utils/test_events.py Normal file
View File

@@ -0,0 +1,250 @@
from collections import defaultdict
from queue import Queue
from unittest.mock import patch
from redis.exceptions import ConnectionError
from CTFd.config import TestingConfig
from CTFd.utils.events import EventManager, RedisEventManager, ServerSentEvent
from tests.helpers import create_ctfd, destroy_ctfd, login_as_user, register_user
def test_event_manager_installed():
"""Test that EventManager is installed on the Flask app"""
app = create_ctfd()
assert type(app.events_manager) == EventManager
destroy_ctfd(app)
def test_event_manager_subscription():
"""Test that EventManager subscribing works"""
with patch.object(Queue, "get") as fake_queue:
saved_data = {
"user_id": None,
"title": "asdf",
"content": "asdf",
"team_id": None,
"user": None,
"team": None,
"date": "2019-01-28T01:20:46.017649+00:00",
"id": 10,
}
saved_event = {"type": "notification", "data": saved_data}
fake_queue.return_value = saved_event
event_manager = EventManager()
events = event_manager.subscribe()
message = next(events)
assert isinstance(message, ServerSentEvent)
assert message.to_dict() == {"data": "ping", "type": "ping"}
assert message.__str__().startswith("event:ping")
assert len(event_manager.clients) == 1
message = next(events)
assert isinstance(message, ServerSentEvent)
assert message.to_dict() == saved_event
assert message.__str__().startswith("event:notification\ndata:")
assert len(event_manager.clients) == 1
def test_event_manager_publish():
"""Test that EventManager publishing to clients works"""
saved_data = {
"user_id": None,
"title": "asdf",
"content": "asdf",
"team_id": None,
"user": None,
"team": None,
"date": "2019-01-28T01:20:46.017649+00:00",
"id": 10,
}
event_manager = EventManager()
q = defaultdict(Queue)
event_manager.clients[id(q)] = q
event_manager.publish(data=saved_data, type="notification", channel="ctf")
event = event_manager.clients[id(q)]["ctf"].get()
event = ServerSentEvent(**event)
assert event.data == saved_data
def test_event_endpoint_is_event_stream():
"""Test that the /events endpoint is text/event-stream"""
app = create_ctfd()
with patch.object(Queue, "get") as fake_queue:
saved_data = {
"user_id": None,
"title": "asdf",
"content": "asdf",
"team_id": None,
"user": None,
"team": None,
"date": "2019-01-28T01:20:46.017649+00:00",
"id": 10,
}
saved_event = {"type": "notification", "data": saved_data}
fake_queue.return_value = saved_event
with app.app_context():
register_user(app)
with login_as_user(app) as client:
r = client.get("/events")
assert "text/event-stream" in r.headers["Content-Type"]
destroy_ctfd(app)
def test_redis_event_manager_installed():
"""Test that RedisEventManager is installed on the Flask app"""
class RedisConfig(TestingConfig):
REDIS_URL = "redis://localhost:6379/1"
CACHE_REDIS_URL = "redis://localhost:6379/1"
CACHE_TYPE = "redis"
try:
app = create_ctfd(config=RedisConfig)
except ConnectionError:
print("Failed to connect to redis. Skipping test.")
else:
with app.app_context():
assert isinstance(app.events_manager, RedisEventManager)
destroy_ctfd(app)
def test_redis_event_manager_subscription():
"""Test that RedisEventManager subscribing works."""
class RedisConfig(TestingConfig):
REDIS_URL = "redis://localhost:6379/2"
CACHE_REDIS_URL = "redis://localhost:6379/2"
CACHE_TYPE = "redis"
try:
app = create_ctfd(config=RedisConfig)
except ConnectionError:
print("Failed to connect to redis. Skipping test.")
else:
with app.app_context():
saved_data = {
"user_id": None,
"title": "asdf",
"content": "asdf",
"team_id": None,
"user": None,
"team": None,
"date": "2019-01-28T01:20:46.017649+00:00",
"id": 10,
}
saved_event = {"type": "notification", "data": saved_data}
with patch.object(Queue, "get") as fake_queue:
fake_queue.return_value = saved_event
event_manager = RedisEventManager()
events = event_manager.subscribe()
message = next(events)
assert isinstance(message, ServerSentEvent)
assert message.to_dict() == {"data": "ping", "type": "ping"}
assert message.__str__().startswith("event:ping")
message = next(events)
assert isinstance(message, ServerSentEvent)
assert message.to_dict() == saved_event
assert message.__str__().startswith("event:notification\ndata:")
destroy_ctfd(app)
def test_redis_event_manager_publish():
"""Test that RedisEventManager publishing to clients works."""
class RedisConfig(TestingConfig):
REDIS_URL = "redis://localhost:6379/3"
CACHE_REDIS_URL = "redis://localhost:6379/3"
CACHE_TYPE = "redis"
try:
app = create_ctfd(config=RedisConfig)
except ConnectionError:
print("Failed to connect to redis. Skipping test.")
else:
with app.app_context():
saved_data = {
"user_id": None,
"title": "asdf",
"content": "asdf",
"team_id": None,
"user": None,
"team": None,
"date": "2019-01-28T01:20:46.017649+00:00",
"id": 10,
}
event_manager = RedisEventManager()
event_manager.publish(data=saved_data, type="notification", channel="ctf")
destroy_ctfd(app)
def test_redis_event_manager_listen():
"""Test that RedisEventManager listening pubsub works."""
# This test is nob currently working properly
# This test is sort of incomplete b/c we aren't also subscribing
# I wasnt able to get listening and subscribing to work at the same time
# But the code does work under gunicorn and serve.py
try:
# import importlib
# from gevent.monkey import patch_time, patch_socket
# from gevent import Timeout
# patch_time()
# patch_socket()
class RedisConfig(TestingConfig):
REDIS_URL = "redis://localhost:6379/4"
CACHE_REDIS_URL = "redis://localhost:6379/4"
CACHE_TYPE = "redis"
try:
app = create_ctfd(config=RedisConfig)
except ConnectionError:
print("Failed to connect to redis. Skipping test.")
else:
with app.app_context():
# saved_event = {
# "data": {
# "team_id": None,
# "user_id": None,
# "content": "asdf",
# "title": "asdf",
# "id": 1,
# "team": None,
# "user": None,
# "date": "2020-08-31T23:57:27.193081+00:00",
# "type": "toast",
# "sound": None,
# },
# "type": "notification",
# }
event_manager = RedisEventManager()
# def disable_retry(f, *args, **kwargs):
# return f()
# with patch("tenacity.retry", side_effect=disable_retry):
# with Timeout(10):
# event_manager.listen()
event_manager.listen()
# event_manager.publish(
# data=saved_event["data"], type="notification", channel="ctf"
# )
destroy_ctfd(app)
finally:
pass
# import socket
# import time
# importlib.reload(socket)
# importlib.reload(time)

104
tests/utils/test_exports.py Normal file
View File

@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
import json
import os
import zipfile
from CTFd.models import Challenges, Flags, Teams, Users
from CTFd.utils import text_type
from CTFd.utils.exports import export_ctf, import_ctf
from tests.helpers import (
create_ctfd,
destroy_ctfd,
gen_challenge,
gen_flag,
gen_hint,
gen_team,
gen_user,
login_as_user,
register_user,
)
def test_export_ctf():
"""Test that CTFd can export the database"""
app = create_ctfd()
if not app.config.get("SQLALCHEMY_DATABASE_URI").startswith("sqlite"):
with app.app_context():
register_user(app)
chal1 = gen_challenge(app.db, name=text_type("🐺"))
gen_challenge(
app.db, name=text_type("🐺"), requirements={"prerequisites": [1]}
)
chal_id = chal1.id
gen_hint(app.db, chal_id)
client = login_as_user(app)
with client.session_transaction():
data = {"target": 1, "type": "hints"}
r = client.post("/api/v1/unlocks", json=data)
output = r.get_data(as_text=True)
json.loads(output)
app.db.session.commit()
backup = export_ctf()
with open("export.test_export_ctf.zip", "wb") as f:
f.write(backup.read())
export = zipfile.ZipFile("export.test_export_ctf.zip", "r")
data = json.loads(export.read("db/challenges.json"))
assert data["results"][1]["requirements"] == {"prerequisites": [1]}
os.remove("export.test_export_ctf.zip")
destroy_ctfd(app)
def test_import_ctf():
"""Test that CTFd can import a CTF"""
app = create_ctfd()
if not app.config.get("SQLALCHEMY_DATABASE_URI").startswith("sqlite"):
with app.app_context():
base_user = "user"
for x in range(10):
user = base_user + str(x)
user_email = user + "@examplectf.com"
gen_user(app.db, name=user, email=user_email)
base_team = "team"
for x in range(5):
team = base_team + str(x)
team_email = team + "@examplectf.com"
gen_team(app.db, name=team, email=team_email)
for x in range(9):
chal = gen_challenge(app.db, name="chal_name{}".format(x))
gen_flag(app.db, challenge_id=chal.id, content="flag")
chal = gen_challenge(
app.db, name="chal_name10", requirements={"prerequisites": [1]}
)
gen_flag(app.db, challenge_id=chal.id, content="flag")
app.db.session.commit()
backup = export_ctf()
with open("export.test_import_ctf.zip", "wb") as f:
f.write(backup.read())
destroy_ctfd(app)
app = create_ctfd()
# TODO: These databases should work but they don't...
if not app.config.get("SQLALCHEMY_DATABASE_URI").startswith("sqlite"):
with app.app_context():
import_ctf("export.test_import_ctf.zip")
if not app.config.get("SQLALCHEMY_DATABASE_URI").startswith("postgres"):
# TODO: Dig deeper into why Postgres fails here
assert Users.query.count() == 31
assert Teams.query.count() == 5
assert Challenges.query.count() == 10
assert Flags.query.count() == 10
chal = Challenges.query.filter_by(name="chal_name10").first()
assert chal.requirements == {"prerequisites": [1]}
destroy_ctfd(app)

View File

@@ -0,0 +1,13 @@
from CTFd.utils.formatters import safe_format
def test_safe_format():
assert safe_format("Message from {ctf_name}", ctf_name="CTF") == "Message from CTF"
assert (
safe_format("Message from {{ ctf_name }}", ctf_name="CTF") == "Message from CTF"
)
assert safe_format("{{ ctf_name }} {{ctf_name}}", ctf_name="CTF") == "CTF CTF"
assert (
safe_format("{ ctf_name } {ctf_name} {asdf}", ctf_name="CTF")
== "CTF CTF {asdf}"
)

View File

@@ -0,0 +1,19 @@
from CTFd.utils.humanize.numbers import ordinalize
def test_ordinalize():
tests = {
1: "1st",
2: "2nd",
3: "3rd",
4: "4th",
11: "11th",
12: "12th",
13: "13th",
101: "101st",
102: "102nd",
103: "103rd",
111: "111th",
}
for t, v in tests.items():
assert ordinalize(t) == v

View File

@@ -0,0 +1,12 @@
from CTFd.utils import markdown
def test_markdown():
"""
Test that our markdown function renders properly
"""
# Allow raw HTML / potentially unsafe HTML
assert (
markdown("<iframe src='https://example.com'></iframe>").strip()
== "<iframe src='https://example.com'></iframe>"
)

View File

@@ -0,0 +1,19 @@
from CTFd.utils.crypto import hash_password, sha256, verify_password
def test_hash_password():
assert hash_password("asdf").startswith("$bcrypt-sha256")
def test_verify_password():
assert verify_password(
"asdf",
"$bcrypt-sha256$2b,12$I0CNXRkGD2Bi/lbC4vZ7Y.$1WoilsadKpOjXa/be9x3dyu7p.mslZ6",
)
def test_sha256():
assert (
sha256("asdf")
== "f0e4c2f76c58916ec258f246851bea091d14d4247a2fc3e18694461b1816e13b"
)

View File

@@ -0,0 +1,57 @@
from CTFd.plugins import register_plugin_script
from CTFd.utils.plugins import override_template
from tests.helpers import create_ctfd, destroy_ctfd, login_as_user
def test_override_template():
"""Does override_template work properly for regular themes"""
app = create_ctfd()
with app.app_context():
override_template("login.html", "LOGIN OVERRIDE")
with app.test_client() as client:
r = client.get("/login")
assert r.status_code == 200
output = r.get_data(as_text=True)
assert "LOGIN OVERRIDE" in output
destroy_ctfd(app)
def test_admin_override_template():
"""Does override_template work properly for the admin panel"""
app = create_ctfd()
with app.app_context():
override_template("admin/users/user.html", "ADMIN TEAM OVERRIDE")
client = login_as_user(app, name="admin", password="password")
r = client.get("/admin/users/1")
assert r.status_code == 200
output = r.get_data(as_text=True)
assert "ADMIN TEAM OVERRIDE" in output
destroy_ctfd(app)
def test_register_plugin_script():
"""Test that register_plugin_script adds script paths to the core theme"""
app = create_ctfd()
with app.app_context():
register_plugin_script("/fake/script/path.js")
register_plugin_script("http://examplectf.com/fake/script/path.js")
with app.test_client() as client:
r = client.get("/")
output = r.get_data(as_text=True)
assert "/fake/script/path.js" in output
assert "http://examplectf.com/fake/script/path.js" in output
destroy_ctfd(app)
def test_register_plugin_stylesheet():
"""Test that register_plugin_stylesheet adds stylesheet paths to the core theme"""
app = create_ctfd()
with app.app_context():
register_plugin_script("/fake/stylesheet/path.css")
register_plugin_script("http://examplectf.com/fake/stylesheet/path.css")
with app.test_client() as client:
r = client.get("/")
output = r.get_data(as_text=True)
assert "/fake/stylesheet/path.css" in output
assert "http://examplectf.com/fake/stylesheet/path.css" in output
destroy_ctfd(app)

View File

@@ -0,0 +1,24 @@
from tests.helpers import create_ctfd, destroy_ctfd, register_user
def test_ratelimit_on_auth():
"""Test that ratelimiting function works properly"""
app = create_ctfd()
with app.app_context():
register_user(app)
with app.test_client() as client:
r = client.get("/login")
with client.session_transaction() as sess:
data = {
"name": "user",
"password": "wrong_password",
"nonce": sess.get("nonce"),
}
for _ in range(10):
r = client.post("/login", data=data)
assert r.status_code == 200
for _ in range(5):
r = client.post("/login", data=data)
assert r.status_code == 429
destroy_ctfd(app)

View File

@@ -0,0 +1,364 @@
from collections import namedtuple
from CTFd.utils.security.sanitize import sanitize_html
Case = namedtuple("Case", ["input", "expected"])
def test_sanitize_html_empty():
"""Test sanitize_html with empty input"""
assert sanitize_html("") == ""
def test_sanitize_html_basic_tags():
"""Test that basic HTML tags are preserved"""
cases = [
Case("<p>Hello World</p>", "<p>Hello World</p>"),
Case("<div>Content</div>", "<div>Content</div>"),
Case("<span>Text</span>", "<span>Text</span>"),
Case("<strong>Bold</strong>", "<strong>Bold</strong>"),
Case("<em>Italic</em>", "<em>Italic</em>"),
Case("<h1>Header</h1>", "<h1>Header</h1>"),
Case("<h2>Header</h2>", "<h2>Header</h2>"),
Case("<h3>Header</h3>", "<h3>Header</h3>"),
Case(
"<ul><li>Item 1</li><li>Item 2</li></ul>",
"<ul><li>Item 1</li><li>Item 2</li></ul>",
),
Case(
"<ol><li>Item 1</li><li>Item 2</li></ol>",
"<ol><li>Item 1</li><li>Item 2</li></ol>",
),
]
for case in cases:
assert sanitize_html(case.input) == case.expected
def test_sanitize_html_links():
"""Test that links are sanitized with proper rel attributes"""
cases = [
Case(
'<a href="https://example.com">Link</a>',
'<a href="https://example.com" rel="noopener noreferrer nofollow">Link</a>',
),
Case(
'<a href="http://example.com">Link</a>',
'<a href="http://example.com" rel="noopener noreferrer nofollow">Link</a>',
),
Case(
'<a href="//example.com">Link</a>',
'<a href="//example.com" rel="noopener noreferrer nofollow">Link</a>',
),
Case(
'<a href="/internal/path">Link</a>',
'<a href="/internal/path" rel="noopener noreferrer nofollow">Link</a>',
),
Case(
'<a href="mailto:test@example.com">Email</a>',
'<a href="mailto:test@example.com" rel="noopener noreferrer nofollow">Email</a>',
),
Case(
'<a href="tel:+1234567890">Phone</a>',
'<a href="tel:+1234567890" rel="noopener noreferrer nofollow">Phone</a>',
),
Case(
'<a href="javascript:alert(1)">Evil</a>',
'<a rel="noopener noreferrer nofollow">Evil</a>',
),
Case(
'<a href="#anchor">Anchor</a>',
'<a href="#anchor" rel="noopener noreferrer nofollow">Anchor</a>',
),
Case(
'<a href="?q=1">Query</a>',
'<a href="?q=1" rel="noopener noreferrer nofollow">Query</a>',
),
Case(
'<a href="?q=1&r=2">Query</a>',
'<a href="?q=1&amp;r=2" rel="noopener noreferrer nofollow">Query</a>',
),
]
for case in cases:
result = sanitize_html(case.input)
assert (
result == case.expected
), f"Input: {case.input}, Expected: {case.expected}, Got: {result}"
def test_sanitize_html_images():
"""Test that images are preserved with allowed attributes"""
cases = [
Case(
'<img src="https://example.com/image.jpg" alt="Test Image">',
'<img src="https://example.com/image.jpg" alt="Test Image">',
),
Case(
'<img src="image.jpg" alt="Local Image" width="100" height="100">',
'<img src="image.jpg" alt="Local Image" width="100" height="100">',
),
Case(
'<img src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg==" alt="Red dot">',
'<img src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg==" alt="Red dot">',
),
Case(
'<img src="image.gif?height=500&width=500" alt="Animated">',
'<img src="image.gif?height=500&amp;width=500" alt="Animated">',
),
]
for case in cases:
result = sanitize_html(case.input)
assert (
result == case.expected
), f"Input: {case.input}, Expected: {case.expected}, Got: {result}"
def test_sanitize_html_dangerous_content():
"""Test that dangerous content is removed or sanitized"""
cases = [
Case('<script>alert("xss")</script>', ""),
Case('<script src="evil.js"></script>', ""),
Case('<object data="evil.swf"></object>', ""),
Case('<embed src="evil.swf">', ""),
Case('<link rel="stylesheet" href="evil.css">', ""),
Case("<div onclick=\"alert('xss')\">Content</div>", "<div>Content</div>"),
Case('<img src="image.jpg" onload="alert(\'xss\')">', '<img src="image.jpg">'),
Case('<img src="image.jpg" onerror="alert(\'xss\')">', '<img src="image.jpg">'),
Case("<body onload=\"alert('xss')\">Content</body>", "Content"),
Case('<iframe src="javascript:alert(1)"></iframe>', "<iframe></iframe>"),
]
for case in cases:
result = sanitize_html(case.input)
assert (
result == case.expected
), f"Input: {case.input}, Expected: {case.expected}, Got: {result}"
def test_sanitize_html_forms():
"""Test that form elements are preserved"""
cases = [
Case(
'<form method="post" action="/submit"><input type="text" name="username"></form>',
'<form method="post" action="/submit"><input type="text" name="username"></form>',
),
Case(
'<textarea name="message" placeholder="Enter message"></textarea>',
'<textarea name="message" placeholder="Enter message"></textarea>',
),
Case(
'<select name="option"><option value="1">Option 1</option></select>',
'<select name="option"><option value="1">Option 1</option></select>',
),
Case(
'<button type="submit">Submit</button>',
'<button type="submit">Submit</button>',
),
Case(
'<label for="username">Username:</label>',
'<label for="username">Username:</label>',
),
]
for case in cases:
result = sanitize_html(case.input)
assert (
result == case.expected
), f"Input: {case.input}, Expected: {case.expected}, Got: {result}"
def test_sanitize_html_media():
"""Test that media elements are preserved"""
cases = [
Case(
'<video controls src="video.mp4" width="640" height="480"></video>',
'<video controls="" src="video.mp4" width="640" height="480"></video>',
),
Case(
'<audio controls src="audio.mp3"></audio>',
'<audio controls="" src="audio.mp3"></audio>',
),
Case(
'<iframe src="https://example.com" width="600" height="400" frameborder="0"></iframe>',
'<iframe src="https://example.com" width="600" height="400" frameborder="0"></iframe>',
),
Case(
'<source src="video.mp4" type="video/mp4">',
'<source src="video.mp4" type="video/mp4">',
),
]
for case in cases:
result = sanitize_html(case.input)
print(result)
assert (
result == case.expected
), f"Input: {case.input}, Expected: {case.expected}, Got: {result}"
def test_sanitize_html_tables():
"""Test that table elements are preserved"""
cases = [
Case(
"<table><tr><th>Header</th><td>Data</td></tr></table>",
"<table><tbody><tr><th>Header</th><td>Data</td></tr></tbody></table>",
),
Case(
'<table border="1" cellpadding="5" cellspacing="0"><tbody><tr><td>Cell</td></tr></tbody></table>',
'<table border="1" cellpadding="5" cellspacing="0"><tbody><tr><td>Cell</td></tr></tbody></table>',
),
]
for case in cases:
result = sanitize_html(case.input)
assert (
result == case.expected
), f"Input: {case.input}, Expected: {case.expected}, Got: {result}"
def test_sanitize_html_attributes():
"""Test that allowed attributes are preserved and dangerous ones removed"""
cases = [
Case(
'<div class="container" id="main" style="color: red;">Content</div>',
'<div class="container" id="main" style="color: red;">Content</div>',
),
Case(
'<div data-toggle="modal" data-target="#myModal">Modal</div>',
'<div data-toggle="modal" data-target="#myModal">Modal</div>',
),
Case(
'<button aria-label="Close" aria-expanded="false">Button</button>',
'<button aria-label="Close" aria-expanded="false">Button</button>',
),
Case(
'<img src="image.jpg" title="Image Title" alt="Alt Text">',
'<img src="image.jpg" title="Image Title" alt="Alt Text">',
),
]
for case in cases:
result = sanitize_html(case.input)
assert (
result == case.expected
), f"Input: {case.input}, Expected: {case.expected}, Got: {result}"
def test_sanitize_html_comments():
"""Test that HTML comments are preserved"""
cases = [
Case(
"<!-- This is a comment --><p>Content</p>",
"<!-- This is a comment --><p>Content</p>",
),
Case(
"<div>Before<!-- comment -->After</div>",
"<div>Before<!-- comment -->After</div>",
),
Case(
"<!--&gt;<img src=x onerror=alert()&gt;>",
"<!--&gt;<img src=x onerror=alert()&gt;>-->",
),
]
for case in cases:
result = sanitize_html(case.input)
assert (
result == case.expected
), f"Input: {case.input}, Expected: {case.expected}, Got: {result}"
def test_sanitize_html_href_sanitization():
"""Test that href attributes are properly sanitized"""
cases = [
Case(
'abc<a href="https://abc&quot;&gt;<script&gt;alert(1)<&#x2f;script/">CLICK</a>',
'abc<a rel="noopener noreferrer nofollow">CLICK</a>',
),
Case(
'<a href="https://abc&quot;&gt;<script&gt;alert(1)<&#x2f;script/">Link</a>',
'<a rel="noopener noreferrer nofollow">Link</a>',
),
]
for case in cases:
result = sanitize_html(case.input)
assert (
result == case.expected
), f"Input: {case.input}, Expected: {case.expected}, Got: {result}"
def test_sanitize_html_malformed():
"""Test sanitize_html with malformed HTML"""
cases = [
Case("<p>Unclosed paragraph", "<p>Unclosed paragraph</p>"),
Case(
"<strong><em>Improperly nested</strong></em>",
"<strong><em>Improperly nested</em></strong>",
),
Case("Text with & ampersand", "Text with &amp; ampersand"),
Case("Text with < less than", "Text with &lt; less than"),
Case("Text with > greater than", "Text with &gt; greater than"),
]
for case in cases:
result = sanitize_html(case.input)
assert (
result == case.expected
), f"Input: {case.input}, Expected: {case.expected}, Got: {result}"
def test_sanitize_html_whitespace():
"""Test that whitespace is preserved correctly"""
cases = [
Case("Hi.\n", "Hi.\n"),
Case("\t\n \n\t", "\t\n \n\t"),
Case(" <p> Spaced content </p> ", " <p> Spaced content </p> "),
Case("<pre> Code with spaces </pre>", "<pre> Code with spaces </pre>"),
]
for case in cases:
result = sanitize_html(case.input)
assert (
result == case.expected
), f"Input: {case.input}, Expected: {case.expected}, Got: {result}"
def test_sanitize_html_complex_content():
"""Test sanitize_html with complex mixed content"""
cases = [
Case(
"""<div class="container">
<h1>Welcome to CTF</h1>
<p>This is a <strong>challenge description</strong> with <em>formatting</em>.</p>
<ul>
<li>Hint 1: Look at the source</li>
<li>Hint 2: Check the headers</li>
</ul>
<p>Visit <a href="https://example.com">this link</a> for more info.</p>
<img src="https://example.com/flag.png" alt="Flag" width="200">
<!-- This is a comment -->
<pre><code>print("Hello World")</code></pre>
</div>""",
"""<div class="container">
<h1>Welcome to CTF</h1>
<p>This is a <strong>challenge description</strong> with <em>formatting</em>.</p>
<ul>
<li>Hint 1: Look at the source</li>
<li>Hint 2: Check the headers</li>
</ul>
<p>Visit <a href="https://example.com" rel="noopener noreferrer nofollow">this link</a> for more info.</p>
<img src="https://example.com/flag.png" alt="Flag" width="200">
<!-- This is a comment -->
<pre><code>print("Hello World")</code></pre>
</div>""",
),
]
for case in cases:
result = sanitize_html(case.input)
assert (
result == case.expected
), f"Input: {case.input}, Expected: {case.expected}, Got: {result}"

View File

@@ -0,0 +1,97 @@
from unittest.mock import Mock, patch
from uuid import UUID
from tests.helpers import create_ctfd, destroy_ctfd, login_as_user, register_user
def test_sessions_set_httponly():
app = create_ctfd()
with app.app_context():
with app.test_client() as client:
r = client.get("/")
cookie = dict(r.headers)["Set-Cookie"]
assert "HttpOnly;" in cookie
destroy_ctfd(app)
def test_sessions_set_samesite():
app = create_ctfd()
with app.app_context():
with app.test_client() as client:
r = client.get("/")
cookie = dict(r.headers)["Set-Cookie"]
assert "SameSite=" in cookie
destroy_ctfd(app)
def test_session_invalidation_on_admin_password_change():
app = create_ctfd()
with app.app_context():
register_user(app)
with login_as_user(app, name="admin") as admin, login_as_user(app) as user:
r = user.get("/settings")
assert r.status_code == 200
r = admin.patch("/api/v1/users/2", json={"password": "password2"})
assert r.status_code == 200
r = user.get("/settings")
# User's password was changed
# They should be logged out
assert r.location.startswith("/login")
assert r.status_code == 302
destroy_ctfd(app)
def test_session_invalidation_on_user_password_change():
app = create_ctfd()
with app.app_context():
register_user(app)
with login_as_user(app) as user:
r = user.get("/settings")
assert r.status_code == 200
data = {"confirm": "password", "password": "new_password"}
r = user.patch("/api/v1/users/me", json=data)
assert r.status_code == 200
r = user.get("/settings")
# User initiated their own password change
# They should not be logged out
assert r.status_code == 200
destroy_ctfd(app)
# @patch.object(uuid, 'uuid4', side_effect=TEST_UUIDS)
# @patch.object(uuid, 'uuid4')
def test_session_with_duplicate_session_id():
app = create_ctfd()
with app.app_context():
register_user(app)
register_user(app, name="user1", email="user1@examplectf.com")
TEST_UUIDS = [
# First user login successful
UUID("2d0ac3a8-b956-491a-9f53-d27cd33f2529"),
UUID("85e61378-5bc4-4cc8-a37e-b03270b7b172"),
# Second user gets a unique UUID then a duplicated one
UUID("c47c907f-d508-4f23-a28a-a1af1e9d3f27"),
UUID("85e61378-5bc4-4cc8-a37e-b03270b7b172"),
UUID("85e61378-5bc4-4cc8-a37e-b03270b7b172"),
UUID("85e61378-5bc4-4cc8-a37e-b03270b7b172"),
UUID("85e61378-5bc4-4cc8-a37e-b03270b7b172"),
UUID("85e61378-5bc4-4cc8-a37e-b03270b7b172"),
# Second user should finally receive a unique UUID
UUID("a00aff35-a12e-465a-8747-e18f78f60b13"),
UUID("da876038-7602-4bb0-88b8-f7104094219f"),
]
uuid_mock = Mock(side_effect=TEST_UUIDS)
with patch(target="CTFd.utils.sessions.uuid4", new=uuid_mock):
login_as_user(app)
with patch(target="CTFd.utils.sessions.uuid4", new=uuid_mock):
login_as_user(app, name="user1")
destroy_ctfd(app)

104
tests/utils/test_updates.py Normal file
View File

@@ -0,0 +1,104 @@
from unittest.mock import Mock, patch
import requests
from CTFd.utils import get_config, set_config
from CTFd.utils.updates import update_check
from tests.helpers import create_ctfd, destroy_ctfd, login_as_user
def test_update_check_is_called():
"""Update checks happen on start"""
app = create_ctfd()
with app.app_context():
assert get_config("version_latest") is None
@patch.object(requests, "get")
def test_update_check_identifies_update(fake_get_request):
"""Update checks properly identify new versions"""
app = create_ctfd()
with app.app_context():
app.config["UPDATE_CHECK"] = True
fake_response = Mock()
fake_get_request.return_value = fake_response
fake_response.json = lambda: {
"resource": {
"download_url": "https://api.github.com/repos/CTFd/CTFd/zipball/9.9.9",
"html_url": "https://github.com/CTFd/CTFd/releases/tag/9.9.9",
"id": 12,
"latest": True,
"next": 1542212248,
"prerelease": False,
"published_at": "Wed, 25 Oct 2017 19:39:42 -0000",
"tag": "9.9.9",
}
}
update_check()
assert (
get_config("version_latest")
== "https://github.com/CTFd/CTFd/releases/tag/9.9.9"
)
assert get_config("next_update_check") == 1542212248
destroy_ctfd(app)
def test_update_check_notifies_user():
"""If an update is available, admin users are notified in the panel"""
app = create_ctfd()
with app.app_context():
app.config["UPDATE_CHECK"] = True
set_config("version_latest", "https://github.com/CTFd/CTFd/releases/tag/9.9.9")
client = login_as_user(app, name="admin", password="password")
r = client.get("/admin/config")
assert r.status_code == 200
response = r.get_data(as_text=True)
assert "https://github.com/CTFd/CTFd/releases/tag/9.9.9" in response
destroy_ctfd(app)
@patch.object(requests, "post")
def test_update_check_ignores_downgrades(fake_post_request):
"""Update checks do nothing on old or same versions"""
app = create_ctfd()
with app.app_context():
app.config["UPDATE_CHECK"] = True
fake_response = Mock()
fake_post_request.return_value = fake_response
fake_response.json = lambda: {
"resource": {
"html_url": "https://github.com/CTFd/CTFd/releases/tag/0.0.1",
"download_url": "https://api.github.com/repos/CTFd/CTFd/zipball/0.0.1",
"published_at": "Wed, 25 Oct 2017 19:39:42 -0000",
"tag": "0.0.1",
"prerelease": False,
"id": 6,
"latest": True,
}
}
update_check()
assert get_config("version_latest") is None
fake_response = Mock()
fake_post_request.return_value = fake_response
fake_response.json = lambda: {
"resource": {
"html_url": "https://github.com/CTFd/CTFd/releases/tag/{}".format(
app.VERSION
),
"download_url": "https://api.github.com/repos/CTFd/CTFd/zipball/{}".format(
app.VERSION
),
"published_at": "Wed, 25 Oct 2017 19:39:42 -0000",
"tag": "{}".format(app.VERSION),
"prerelease": False,
"id": 6,
"latest": True,
}
}
update_check()
assert get_config("version_latest") is None
destroy_ctfd(app)

View File

@@ -0,0 +1,97 @@
import os
from io import BytesIO
import boto3
from moto import mock_s3
from CTFd.utils.uploads import S3Uploader, rmdir
from tests.helpers import create_ctfd, destroy_ctfd
@mock_s3
def test_s3_uploader():
conn = boto3.resource("s3", region_name="test-region")
conn.create_bucket(
Bucket="bucket", CreateBucketConfiguration={"LocationConstraint": "test-region"}
)
app = create_ctfd()
with app.app_context():
app.config["UPLOAD_PROVIDER"] = "s3"
app.config["AWS_ACCESS_KEY_ID"] = "AKIAIOSFODNN7EXAMPLE"
app.config["AWS_SECRET_ACCESS_KEY"] = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
app.config["AWS_S3_BUCKET"] = "bucket"
app.config["AWS_S3_REGION"] = "test-region"
uploader = S3Uploader()
assert uploader.s3
assert uploader.bucket == "bucket"
fake_file = BytesIO("fakedfile".encode())
path = uploader.upload(fake_file, "fake_file.txt")
assert "fake_file.txt" in uploader.download(path).location
destroy_ctfd(app)
@mock_s3
def test_s3_uploader_custom_prefix():
conn = boto3.resource("s3", region_name="test-region")
conn.create_bucket(
Bucket="bucket", CreateBucketConfiguration={"LocationConstraint": "test-region"}
)
app = create_ctfd()
with app.app_context():
app.config["UPLOAD_PROVIDER"] = "s3"
app.config["AWS_ACCESS_KEY_ID"] = "AKIAIOSFODNN7EXAMPLE"
app.config["AWS_SECRET_ACCESS_KEY"] = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
app.config["AWS_S3_BUCKET"] = "bucket"
app.config["AWS_S3_REGION"] = "test-region"
app.config["AWS_S3_CUSTOM_PREFIX"] = "prefix"
uploader = S3Uploader()
assert uploader.s3
assert uploader.bucket == "bucket"
fake_file = BytesIO("fakedfile".encode())
path = uploader.upload(fake_file, "fake_file.txt")
assert "fake_file.txt" in uploader.download(path).location
fake_file2 = BytesIO("fakedfile".encode())
path2 = uploader.upload(fake_file2, "fake_file.txt", "path")
assert "/prefix/path/fake_file.txt" in uploader.download(path2).location
destroy_ctfd(app)
@mock_s3
def test_s3_sync():
conn = boto3.resource("s3", region_name="test-region")
conn.create_bucket(
Bucket="bucket", CreateBucketConfiguration={"LocationConstraint": "test-region"}
)
app = create_ctfd()
with app.app_context():
app.config["UPLOAD_PROVIDER"] = "s3"
app.config["AWS_ACCESS_KEY_ID"] = "AKIAIOSFODNN7EXAMPLE"
app.config["AWS_SECRET_ACCESS_KEY"] = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
app.config["AWS_S3_BUCKET"] = "bucket"
app.config["AWS_S3_REGION"] = "test-region"
uploader = S3Uploader()
uploader.sync()
fake_file = BytesIO("fakedfile".encode())
path = uploader.upload(fake_file, "fake_file.txt")
full_path = os.path.join(app.config["UPLOAD_FOLDER"], path)
try:
uploader.sync()
with open(full_path) as f:
assert f.read() == "fakedfile"
finally:
rmdir(os.path.dirname(full_path))
destroy_ctfd(app)

View File

@@ -0,0 +1,36 @@
from marshmallow import ValidationError
from CTFd.utils.validators import validate_country_code, validate_email
def test_validate_country_code():
assert validate_country_code("") is None
# TODO: This looks poor, when everything moves to pytest we should remove exception catches like this.
try:
validate_country_code("ZZ")
except ValidationError:
pass
def test_validate_email():
"""Test that the check_email_format() works properly"""
assert validate_email("user@examplectf.com") is True
assert validate_email("user+plus@gmail.com") is True
assert validate_email("user.period1234@gmail.com") is True
assert validate_email("user.period1234@b.c") is True
assert validate_email("user.period1234@b") is False
assert validate_email("no.ampersand") is False
assert validate_email("user@") is False
assert validate_email("@examplectf.com") is False
assert validate_email("user.io@ctfd") is False
assert validate_email("user\\@ctfd") is False
for invalid_email in [
"user.@examplectf.com",
".user@examplectf.com",
"user@ctfd..io",
]:
try:
assert validate_email(invalid_email) is False
except AssertionError:
print(invalid_email, "did not pass validation")