Added password flow authentication.
This commit is contained in:
parent
8e208e2f99
commit
6d282b6e83
|
@ -1,4 +1,7 @@
|
|||
HOST=<IP or hostname>
|
||||
PORT=<port number>
|
||||
ENV=<dev or prod>
|
||||
SQLALCHEMY_DATABASE_URL=<file location>
|
||||
SQLALCHEMY_DATABASE_URL=<file location>
|
||||
SECRET_KEY=<JWT secret>
|
||||
ALGORITHM=<JWT algorithm>
|
||||
TOKEN_EXPIRE_DAYS=<when token expires>
|
|
@ -9,3 +9,7 @@ PORT = int(os.getenv('PORT'))
|
|||
ENV = os.getenv('ENV')
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = os.getenv('SQLALCHEMY_DATABASE_URL')
|
||||
|
||||
SECRET_KEY = os.getenv('SECRET_KEY')
|
||||
ALGORITHM = os.getenv('ALGORITHM')
|
||||
TOKEN_EXPIRE_DAYS = int(os.getenv('TOKEN_EXPIRE_DAYS'))
|
||||
|
|
|
@ -12,3 +12,11 @@ engine = create_engine(
|
|||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
|
|
@ -4,6 +4,14 @@ from sqlalchemy.orm import relationship
|
|||
from .database import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'users'
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
username = Column(String, unique=True, index=True)
|
||||
hashed_password = Column(String)
|
||||
|
||||
|
||||
class Owner(Base):
|
||||
__tablename__ = 'owners'
|
||||
|
||||
|
|
|
@ -36,3 +36,19 @@ class Owner(OwnerBase):
|
|||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
username: str
|
||||
|
||||
|
||||
class UserIn(UserBase):
|
||||
hashed_password: str
|
||||
|
||||
|
||||
class User(UserBase):
|
||||
id: int
|
||||
hashed_password: str
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
|
43
main.py
43
main.py
|
@ -2,54 +2,69 @@ from typing import List
|
|||
|
||||
import uvicorn
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from secure.authenticate import oauth2_scheme, get_token
|
||||
import config
|
||||
from db import crud, models, schemas
|
||||
from db.database import SessionLocal, engine
|
||||
from db.database import engine, get_db
|
||||
|
||||
|
||||
models.Base.metadata.create_all(bind=engine)
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
@app.post('/token')
|
||||
def login(form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: Session = Depends(get_db)):
|
||||
return get_token(db, form_data)
|
||||
|
||||
|
||||
@app.post('/owners/', response_model=schemas.Owner)
|
||||
def create_owner(owner: schemas.OwnerIn, db: Session = Depends(get_db)):
|
||||
def create_owner(owner: schemas.OwnerIn,
|
||||
db: Session = Depends(get_db),
|
||||
token: str = Depends(oauth2_scheme)):
|
||||
db_owner = crud.get_owner_by_name(db, name=owner.name)
|
||||
if db_owner:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Owner already exists.')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Owner already exists.')
|
||||
return crud.create_owner(db=db, owner=owner)
|
||||
|
||||
|
||||
@app.get('/owners/', response_model=List[schemas.Owner])
|
||||
def read_owners(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
|
||||
def read_owners(skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
token: str = Depends(oauth2_scheme)):
|
||||
owners = crud.get_owners(db, skip=skip, limit=limit)
|
||||
return owners
|
||||
|
||||
|
||||
@app.get('/owners/{owner_id}', response_model=schemas.Owner)
|
||||
def read_owner(owner_id: int, db: Session = Depends(get_db)):
|
||||
def read_owner(owner_id: int, db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)):
|
||||
db_owner = crud.get_owner(db, owner_id=owner_id)
|
||||
if not db_owner:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail='Owner not found.')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Owner not found.')
|
||||
return db_owner
|
||||
|
||||
|
||||
@app.post('/owners/{owner_id}/readings/', response_model=schemas.Reading)
|
||||
def create_reading_for_owner(owner_id: int, reading: schemas.ReadingIn, db: Session = Depends(get_db)):
|
||||
def create_reading_for_owner(owner_id: int,
|
||||
reading: schemas.ReadingIn,
|
||||
db: Session = Depends(get_db),
|
||||
token: str = Depends(oauth2_scheme)):
|
||||
return crud.create_owner_reading(db=db, reading=reading, owner_id=owner_id)
|
||||
|
||||
|
||||
@app.get('/readings/', response_model=List[schemas.Reading])
|
||||
def read_readings(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
|
||||
def read_readings(skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
token: str = Depends(oauth2_scheme)):
|
||||
readings = crud.get_readings(db, skip=skip, limit=limit)
|
||||
return readings
|
||||
|
||||
|
|
|
@ -2,4 +2,5 @@ uvicorn~=0.15.0
|
|||
fastapi~=0.70.0
|
||||
python-dotenv~=0.19.2
|
||||
SQLAlchemy~=1.4.27
|
||||
pydantic~=1.8.2
|
||||
pydantic~=1.8.2
|
||||
passlib~=1.7.4
|
|
@ -0,0 +1,87 @@
|
|||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import config
|
||||
from .schemas import TokenData
|
||||
from db import models, schemas
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl='token')
|
||||
pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
|
||||
|
||||
|
||||
def verify_password(plain_password, hashed_password):
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password):
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def get_user(db: Session, username: str):
|
||||
user = db.query(models.User).filter(models.User.username == username).first()
|
||||
if user:
|
||||
return schemas.UserIn(
|
||||
**{'username': user.username, 'hashed_password': user.hashed_password}
|
||||
)
|
||||
|
||||
|
||||
def authenticate_user(db: Session, username: str, password: str):
|
||||
user = get_user(db, username)
|
||||
if not user:
|
||||
return False
|
||||
if not verify_password(password, user.hashed_password):
|
||||
return False
|
||||
return user
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=15)
|
||||
to_encode.update({'exp': expire})
|
||||
encoded_jwt = jwt.encode(to_encode, config.SECRET_KEY, algorithm=config.ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def get_token(db: Session, form_data):
|
||||
user = authenticate_user(db, form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='Incorrect username or password',
|
||||
headers={'WWW-Authenticate': 'Bearer'}
|
||||
)
|
||||
access_token_expires = timedelta(days=config.TOKEN_EXPIRE_DAYS)
|
||||
access_token = create_access_token(
|
||||
data={'sub': user.username},
|
||||
expires_delta=access_token_expires
|
||||
)
|
||||
return {'access_token': access_token, 'token_type': 'bearer'}
|
||||
|
||||
|
||||
def get_current_user(db: Session, token: str = Depends(oauth2_scheme)):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='Could not validate credentials.',
|
||||
headers={'WWW-Authenticate': 'Bearer'}
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
|
||||
username: str = payload.get('sub')
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
user = get_user(db, username=token_data.username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
|
@ -0,0 +1,12 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: Optional[str] = None
|
Loading…
Reference in New Issue