1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
use crate::conversion::ByteConvertible;
use crate::repository::daoimpl::InMemoryUserDao;
use log::{debug, error, info, trace};
use tokio::sync::Mutex;
use tonic::{Request, Response, Status};
use uuid::Uuid;

use crate::{
    chaum_pedersen::{ChaumPedersen, GroupParams},
    repository::{dao::UserDao, models::User, session::update_session},
};

// Protobuf generated module: `tonic::include_proto!` pulls in the Rust code
// generated for the `zkp_auth` protobuf package. The request/response message
// types and the `auth_server::Auth` service trait imported below live here.
pub mod zkp_auth {
    tonic::include_proto!("zkp_auth");
}

// Protobuf imports
use zkp_auth::{
    auth_server::Auth, AuthenticationAnswerRequest, AuthenticationAnswerResponse,
    AuthenticationChallengeRequest, AuthenticationChallengeResponse, RegisterRequest,
    RegisterResponse,
};

/// A struct representing the zero-knowledge authentication service.
/// It supports different types of Chaum-Pedersen protocols.
///
/// The service keeps the protocol's group parameters together with a user
/// store; the gRPC handlers in this file lock the store for each access.
///
/// # Type Parameters
///
/// * `C`: Represents the type of Chaum-Pedersen protocol.
/// * `T`: The type used for group elements.
/// * `S`: The type used for scalar values.
pub struct ZkAuth<C, T, S> {
    // Group parameters handed to `C::challenge` and `C::verify`.
    params: GroupParams<T>,
    // User / challenge store behind an async (tokio) mutex; boxed as a trait
    // object so the concrete `UserDao` implementation is not part of the type.
    dao: Mutex<Box<dyn UserDao<T, S> + Send + Sync>>,
    // Zero-sized marker: `C` is only used through its associated functions,
    // no value of type `C` is stored.
    _type_phantom: std::marker::PhantomData<C>,
    // Zero-sized marker for the scalar type `S`.
    _scalar_phantom: std::marker::PhantomData<S>,
}

impl<C, T, S> ZkAuth<C, T, S>
where
    T: Send + Sync + Clone + ByteConvertible<T> + 'static,
    S: Send + Sync + Clone + ByteConvertible<S> + 'static,
{
    /// Creates a service instance for the given group parameters.
    ///
    /// The user store is a freshly constructed [`InMemoryUserDao`], boxed as a
    /// `UserDao` trait object and wrapped in an async mutex.
    ///
    /// # Arguments
    ///
    /// * `params` - Group parameters used by the protocol implementation `C`.
    pub fn new(params: GroupParams<T>) -> Self {
        // The annotation performs the unsized coercion to the trait object.
        let store: Box<dyn UserDao<T, S> + Send + Sync> =
            Box::new(InMemoryUserDao::<T, S>::new());

        Self {
            params,
            dao: Mutex::new(store),
            _type_phantom: std::marker::PhantomData,
            _scalar_phantom: std::marker::PhantomData,
        }
    }
}

/// Implementation of the `Auth` trait for `ZkAuth`.
///
/// This implementation provides the necessary methods for user registration,
/// creating authentication challenges, and verifying authentication answers.
/// It uses generic parameters `C`, `T`, and `S` to work with different cryptographic protocols and data types.
///
/// `C` represents a specific Chaum-Pedersen protocol implementation.
/// `T` is the type for group parameters and public information.
/// `S` is the scalar type used for cryptographic operations.
#[tonic::async_trait]
impl<C, T, S> Auth for ZkAuth<C, T, S>
where
    T: Send + Sync + 'static + Clone + ByteConvertible<T>,
    S: Send + Sync + 'static + Clone + ByteConvertible<S>,
    C: ChaumPedersen<
            Response = S,
            CommitmentRandom = S,
            Challenge = S,
            Secret = S,
            GroupParameters = GroupParams<T>,
            CommitParameters = (T, T, T, T),
        >
        + 'static
        + std::marker::Sync
        + std::marker::Send,
{
    // Register a user with provided credentials.
    // This method accepts a `RegisterRequest` and returns a `RegisterResponse`.
    //
    // # Arguments
    // * `request` - A `Request<RegisterRequest>` containing the user's registration information.
    //
    // # Returns
    // A `Result` containing a `Response<RegisterResponse>` on success, or a `Status` error on failure.
    async fn register(
        &self, request: Request<RegisterRequest>,
    ) -> Result<Response<RegisterResponse>, Status> {
        trace!("register: {:?}", request);
        let req = request.into_inner();

        let y1 =
            T::convert_from(&req.y1).or_else(|_| Err(Status::invalid_argument("Invalid y1")))?;
        let y2 =
            T::convert_from(&req.y2).or_else(|_| Err(Status::invalid_argument("Invalid y2")))?;

        let user = User {
            username: req.user.clone(),
            y1,
            y2,
            r1: None,
            r2: None,
        };

        let mut dao = self.dao.lock().await;
        dao.create(user);

        let reply = RegisterResponse {};
        trace!("register reply: {:?}", reply);
        Ok(Response::new(reply))
    }

    // Create an authentication challenge for a user.
    // This method accepts an `AuthenticationChallengeRequest` and returns an `AuthenticationChallengeResponse`.
    //
    // # Arguments
    // * `request` - A `Request<AuthenticationChallengeRequest>` containing the user's information.
    //
    // # Returns
    // A `Result` containing a `Response<AuthenticationChallengeResponse>` on success, or a `Status` error on failure.
    async fn create_authentication_challenge(
        &self, request: Request<AuthenticationChallengeRequest>,
    ) -> Result<Response<AuthenticationChallengeResponse>, Status> {
        trace!("create_authentication_challenge request: {:?}", request);
        let req = request.into_inner();
        let challenge = C::challenge(&self.params);

        let user = {
            let mut dao = self.dao.lock().await;
            let mut user = dao
                .read(&req.user)
                .ok_or_else(|| Status::not_found("User not found"))?;
            user.r1 = Some(
                T::convert_from(&req.r1)
                    .or_else(|_| Err(Status::invalid_argument("Invalid r1")))?,
            );
            user.r2 = Some(
                T::convert_from(&req.r2)
                    .or_else(|_| Err(Status::invalid_argument("Invalid r2")))?,
            );
            user.clone()
        };

        let auth_id = {
            let mut dao = self.dao.lock().await;
            dao.update(&user.username, user.clone());
            dao.create_auth_challenge(&req.user, &challenge)
        };

        let reply = AuthenticationChallengeResponse {
            auth_id,
            c: S::convert_to(&challenge),
        };
        trace!("create_authentication_challenge reply: {:?}", reply);
        Ok(Response::new(reply))
    }

    // Verify an authentication challenge answer from a user.
    // This method accepts an `AuthenticationAnswerRequest` and returns an `AuthenticationAnswerResponse`.
    //
    // # Arguments
    // * `request` - A `Request<AuthenticationAnswerRequest>` containing the user's authentication answer.
    //
    // # Returns
    // A `Result` containing a `Response<AuthenticationAnswerResponse>` on success, or a `Status` error on failure.
    async fn verify_authentication(
        &self, request: Request<AuthenticationAnswerRequest>,
    ) -> Result<Response<AuthenticationAnswerResponse>, Status> {
        trace!("verify_authentication: {:?}", request);
        let req = request.into_inner();

        let challenge = {
            let mut dao = self.dao.lock().await;
            dao.get_authentication_challenge(&req.auth_id)
                .ok_or_else(|| Status::not_found("Challenge not found"))?
        };

        let user = {
            let mut dao = self.dao.lock().await;
            dao.read(&challenge.user)
                .ok_or_else(|| Status::not_found("User not found"))?
        };

        let s = S::convert_from(&req.s).or_else(|_| Err(Status::invalid_argument("Invalid s")))?;
        let params = self.params.clone();
        let verified = C::verify(
            &params,
            &s,
            &challenge.c,
            &(user.y1, user.y2, user.r1.unwrap(), user.r2.unwrap()),
        );

        debug!("User: {} verified", user.username);
        if !verified {
            error!("Invalid authentication for user: {}", user.username);
            return Err(Status::invalid_argument("Invalid authentication"));
        }
        let session_id = Uuid::new_v4().to_string();
        update_session(user.username.clone(), session_id.clone()); // Clone session_id before moving it
        let reply = AuthenticationAnswerResponse { session_id };

        let mut dao = self.dao.lock().await;
        dao.delete_auth_challenge(&req.auth_id);

        info!("🔑 User: {} authenticated, session id: {}", user.username, req.auth_id);
        trace!("verify_authentication reply: {:?}", reply);
        Ok(Response::new(reply))
    }
}